[mip] Add defines for the number of mip modes. Fix mip cost calculation: if mip is enabled, the cost of writing the mip flag must always be included. Some code cleanup.

This commit is contained in:
siivonek 2022-02-10 02:12:06 +02:00
parent 09f3af81c6
commit ac45a5299c
5 changed files with 33 additions and 36 deletions

View file

@ -725,10 +725,8 @@ void kvz_mip_pred_upsampling_1D(int* const dst, const int* const src, const int*
/** \brief Matrix weighted intra prediction.
*/
// MIP_TODO: remove color parameter if it is not used
void kvz_mip_predict(encoder_state_t const* const state, kvz_intra_references* const refs,
const uint16_t pred_block_width, const uint16_t pred_block_height,
const color_t color,
kvz_pixel* dst,
const int mip_mode, const bool mip_transp)
{

View file

@ -160,7 +160,6 @@ void kvz_mip_predict(
kvz_intra_references * refs,
const uint16_t width,
const uint16_t height,
const color_t color,
kvz_pixel* dst,
const int mip_mode,
const bool mip_transp

View file

@ -504,8 +504,11 @@ static double calc_mode_bits(const encoder_state_t *state,
kvz_intra_get_dir_luma_predictor(x, y, candidate_modes, cur_cu, left_cu, above_cu);
}
// MIP_TODO: calculation of MIP mode cost if this CU has MIP enabled.
double mode_bits = kvz_luma_mode_bits(state, cur_cu->intra.mode, candidate_modes, cur_cu->intra.multi_ref_idx, 0, 0);
int width = LCU_WIDTH >> depth;
int height = width; // TODO: height for non-square blocks
int num_mip_modes_half = NUM_MIP_MODES_HALF(width, height);
int mip_flag_ctx_id = kvz_get_mip_flag_context(x, y, width, height, lcu, NULL);
double mode_bits = kvz_luma_mode_bits(state, cur_cu->intra.mode, candidate_modes, cur_cu->intra.multi_ref_idx, num_mip_modes_half, mip_flag_ctx_id);
if (((depth == 4 && x % 8 && y % 8) || (depth != 4)) && state->encoder_control->chroma_format != KVZ_CSP_400) {
mode_bits += kvz_chroma_mode_bits(state, cur_cu->intra.mode_chroma, cur_cu->intra.mode);

View file

@ -44,6 +44,9 @@
#include "image.h"
#include "constraint.h"
/**
 * \brief Total number of MIP modes available for a block of the given size.
 *
 * - 4x4 blocks: 32 modes.
 * - Blocks where either dimension is 4, and 8x8 blocks: 16 modes.
 * - All other block sizes: 12 modes.
 *
 * The expansion and every argument are parenthesized so the macro can be used
 * safely inside larger expressions. Arguments are evaluated more than once,
 * so they must be free of side effects.
 */
#define NUM_MIP_MODES_FULL(width, height) \
  ((((width) == 4) && ((height) == 4)) ? 32 : \
   ((((width) == 4) || ((height) == 4) || (((width) == 8) && ((height) == 8))) ? 16 : 12))

/**
 * \brief Half of NUM_MIP_MODES_FULL; used as the transpose offset.
 *
 * The outer parentheses are required: without them the shift would bind only
 * to the else-branch of the conditional inside NUM_MIP_MODES_FULL, and 4x4
 * blocks would yield 32 instead of 16.
 */
#define NUM_MIP_MODES_HALF(width, height) (NUM_MIP_MODES_FULL((width), (height)) >> 1)
void kvz_sort_modes(int8_t *__restrict modes, double *__restrict costs, uint8_t length);
void kvz_sort_modes_intra_luma(int8_t *__restrict modes, int8_t *__restrict trafo, double *__restrict costs, uint8_t length);

View file

@ -719,7 +719,7 @@ static int8_t search_intra_rdo(encoder_state_t * const state,
int8_t *intra_preds,
int modes_to_check,
int8_t modes[67], int8_t trafo[67], double costs[67],
int num_mip_modes,
int num_mip_modes_full,
int8_t mip_modes[32], int8_t mip_trafo[32], double mip_costs[32],
lcu_t *lcu,
uint8_t multi_ref_idx)
@ -756,14 +756,15 @@ static int8_t search_intra_rdo(encoder_state_t * const state,
}
// MIP_TODO: implement this inside the standard intra for loop. Code duplication is bad.
// MIP_TODO: loop through normal intra modes first
// MIP search
const int transp_off = num_mip_modes >> 1;
for (uint8_t mip_mode = 0; mip_mode < num_mip_modes; ++mip_mode) {
// Derive mip flag context id
uint8_t ctx_id = kvz_get_mip_flag_context(x_px, y_px, width, height, lcu, NULL);
int rdo_bitcost = kvz_luma_mode_bits(state, mip_modes[mip_mode], intra_preds, 0, num_mip_modes, ctx_id);
const int transp_off = num_mip_modes_full >> 1;
// Derive mip flag context id
uint8_t ctx_id = kvz_get_mip_flag_context(x_px, y_px, width, height, lcu, NULL);
for (uint8_t mip_mode = 0; mip_mode < num_mip_modes_full; ++mip_mode) {
int rdo_bitcost = kvz_luma_mode_bits(state, mip_modes[mip_mode], intra_preds, 0, transp_off, ctx_id);
mip_costs[mip_mode] = rdo_bitcost * (int)(state->lambda + 0.5); // MIP_TODO: check if this is also correct in the case when MIP is used.
mip_costs[mip_mode] = rdo_bitcost * (int)(state->lambda + 0.5);
const bool is_transposed = (mip_modes[mip_mode] >= transp_off ? true : false);
// There can be 32 MIP modes, but only mode numbers [0, 15] are ever written to bitstream.
@ -791,7 +792,6 @@ static int8_t search_intra_rdo(encoder_state_t * const state,
mip_costs[mip_mode] += mode_cost;
mip_trafo[mip_mode] = pred_cu.tr_idx;
// MIP_TODO: check if ET is viable when MIP is used
// Early termination if no coefficients has to be coded
if (state->encoder_control->cfg.intra_rdo_et && !cbf_is_set_any(pred_cu.cbf, depth)) {
modes_to_check = mip_mode + 1;
@ -834,8 +834,8 @@ static int8_t search_intra_rdo(encoder_state_t * const state,
// Update order according to new costs
kvz_sort_modes_intra_luma(modes, trafo, costs, modes_to_check);
bool use_mip = false;
if (num_mip_modes) {
kvz_sort_modes_intra_luma(mip_modes, mip_trafo, mip_costs, num_mip_modes);
if (num_mip_modes_full) {
kvz_sort_modes_intra_luma(mip_modes, mip_trafo, mip_costs, num_mip_modes_full);
if (costs[0] > mip_costs[0]) {
use_mip = true;
}
@ -854,7 +854,7 @@ static int8_t search_intra_rdo(encoder_state_t * const state,
pred_cu.intra.mode = mip_modes[0];
pred_cu.intra.mode_chroma = 0;
pred_cu.intra.multi_ref_idx = 0;
int transp_off = num_mip_modes >> 1;
int transp_off = num_mip_modes_full >> 1;
bool is_transposed = (mip_modes[0] >= transp_off ? true : false);
int8_t pred_mode = (is_transposed ? mip_modes[0] - transp_off : mip_modes[0]);
pred_cu.intra.mode = pred_mode;
@ -877,12 +877,14 @@ static int8_t search_intra_rdo(encoder_state_t * const state,
}
double kvz_luma_mode_bits(const encoder_state_t *state, int8_t luma_mode, const int8_t *intra_preds, const uint8_t multi_ref_idx, const uint8_t num_mip_modes, int mip_flag_ctx_id)
double kvz_luma_mode_bits(const encoder_state_t *state, int8_t luma_mode, const int8_t *intra_preds, const uint8_t multi_ref_idx, const uint8_t num_mip_modes_half, int mip_flag_ctx_id)
{
double mode_bits = 0.0;
bool enable_mip = state->encoder_control->cfg.mip ? (num_mip_modes > 0 ? true : false) : false;
bool enable_mip = state->encoder_control->cfg.mip;
bool mip_flag = enable_mip ? (num_mip_modes_half > 0 ? true : false) : false;
// Mip flag cost must be calculated even if mip is not used in this block
if (enable_mip) {
// Make a copy of state->cabac for bit cost estimation.
cabac_data_t state_cabac_copy;
@ -896,24 +898,25 @@ double kvz_luma_mode_bits(const encoder_state_t *state, int8_t luma_mode, const
cabac = &state_cabac_copy;
// Do cabac writes as normal
const int transp_off = num_mip_modes >> 1;
bool mip_flag = enable_mip;
const int transp_off = num_mip_modes_half;
const bool is_transposed = luma_mode >= transp_off ? true : false;
int8_t mip_mode = is_transposed ? luma_mode - transp_off : luma_mode;
// Write MIP flag
cabac->cur_ctx = &(cabac->ctx.mip_flag[mip_flag_ctx_id]);
CABAC_BIN(cabac, mip_flag, "mip_flag");
if (mip_flag) {
// Write MIP transpose flag & mode
CABAC_BIN_EP(cabac, is_transposed, "mip_transposed");
kvz_cabac_encode_trunc_bin(cabac, mip_mode, transp_off);
}
// Writes done. Get bit cost out of cabac
mode_bits += (23 - state_cabac_copy.bits_left) + (state_cabac_copy.num_buffered_bytes << 3); // MIP_TODO: check what this bit shifting means.
// Write is done. Get bit cost out of cabac
mode_bits += (23 - state_cabac_copy.bits_left) + (state_cabac_copy.num_buffered_bytes << 3);
}
else {
if (!mip_flag) {
int8_t mode_in_preds = -1;
for (int i = 0; i < INTRA_MPM_COUNT; ++i) {
if (luma_mode == intra_preds[i]) {
@ -1211,18 +1214,9 @@ void kvz_search_cu_intra(encoder_state_t * const state,
mip_modes[i] = i;
mip_costs[i] = MAX_INT;
}
// MIP_TODO: check for illegal block sizes.
if (width == 4 && height == 4) {
// Mip size_id = 0. Num modes = 32
num_mip_modes = 32;
}
else if (width == 4 || height == 4 || (width == 8 && height == 8)) {
// Mip size_id = 1. Num modes = 16
num_mip_modes = 16;
}
else {
// Mip size_id = 2. Num modes = 12
num_mip_modes = 12;
// MIP is not allowed for 64 x 4 or 4 x 64 blocks
if (!((width == 64 && height == 4) || (width == 4 && height == 64))) {
num_mip_modes = NUM_MIP_MODES_FULL(width, height);
}
}