[mip] Fix issue with invalid MIP modes written into cabac. Fix Mip mode cost estimation. Implement function to derive mip flag context id. Fix some asserts. Replace floor log 2 implementation with existing kvz math function.

This commit is contained in:
siivonek 2022-02-01 21:09:36 +02:00
parent d2c24c9a0c
commit df5cbbe82f
4 changed files with 47 additions and 101 deletions

View file

@ -873,28 +873,13 @@ static void encode_intra_coding_unit(encoder_state_t * const state,
} }
if (mip_flag) { if (mip_flag) {
assert(mip_mode >= 0 && mip_mode < 16 && "MIP mode must be between [0, 15]"); assert(mip_mode >= 0 && mip_mode < num_mip_modes && "Invalid MIP mode.");
} }
if (cur_cu->type == CU_INTRA && !cur_cu->bdpcmMode && enable_mip) { if (cur_cu->type == CU_INTRA && !cur_cu->bdpcmMode && enable_mip) {
// Derive mip flag context id
uint8_t ctx_id = 0;
const int cu_width = LCU_WIDTH >> depth; const int cu_width = LCU_WIDTH >> depth;
const int cu_height = cu_width; // TODO: height for non-square blocks const int cu_height = cu_width; // TODO: height for non-square blocks
const int pu_x = PU_GET_X(cur_cu->part_size, cu_width, x, 0); uint8_t ctx_id = kvz_get_mip_flag_context(x, y, cu_width, cu_height, NULL, frame->cu_array);
const int pu_y = PU_GET_Y(cur_cu->part_size, cu_height, y, 0);
if (pu_x > 0) {
assert(pu_x >> 2 > 0);
// Get mip flag from left PU
ctx_id = kvz_cu_array_at_const(frame->cu_array, pu_x - 1, pu_y)->intra.mip_flag ? 1 : 0;
}
if (pu_y % LCU_WIDTH > 0 && pu_y > 0) {
assert(pu_y >> 2 > 0);
// Get mip flag from above PU
ctx_id += kvz_cu_array_at_const(frame->cu_array, pu_x, pu_y - 1)->intra.mip_flag ? 1 : 0;
}
ctx_id = (cu_width > 2 * cu_height || cu_height > 2 * cu_width) ? 3 : ctx_id;
// Write MIP flag // Write MIP flag
cabac->cur_ctx = &(cabac->ctx.mip_flag[ctx_id]); cabac->cur_ctx = &(cabac->ctx.mip_flag[ctx_id]);

View file

@ -546,37 +546,41 @@ void kvz_predict_cclm(
} }
int kvz_get_mip_flag_context(int x, int y, int width, int height, lcu_t* const lcu, cu_array_t* const cu_a) {
assert(!(lcu && cu_a));
int context = 0;
if (lcu) {
int x_local = SUB_SCU(x);
int y_local = SUB_SCU(y);
if (x) {
context += LCU_GET_CU_AT_PX(lcu, x_local - 1, y_local)->intra.mip_flag;
}
if (y) {
context += LCU_GET_CU_AT_PX(lcu, x_local, y_local - 1)->intra.mip_flag;
}
context = (width > 2 * height || height > 2 * width) ? 3 : context;
}
else {
if (x > 0) {
context += kvz_cu_array_at_const(cu_a, x - 1, y)->intra.mip_flag;
}
if (y > 0) {
context += kvz_cu_array_at_const(cu_a, x, y - 1)->intra.mip_flag;
}
context = (width > 2 * height || height > 2 * width) ? 3 : context;
}
return context;
}
void kvz_mip_boundary_downsampling_1D(kvz_pixel* reduced_dst, const kvz_pixel* const ref_src, int src_len, int dst_len) void kvz_mip_boundary_downsampling_1D(kvz_pixel* reduced_dst, const kvz_pixel* const ref_src, int src_len, int dst_len)
{ {
if (dst_len < src_len) if (dst_len < src_len)
{ {
// Create reduced boundary by downsampling // Create reduced boundary by downsampling
uint16_t down_smp_factor = src_len / dst_len; uint16_t down_smp_factor = src_len / dst_len;
const int log2_factor = kvz_math_floor_log2(down_smp_factor);
// Calculate floor log2. MIP_TODO: find a better / faster solution
int tmp = 0;
if (down_smp_factor & 0xffff0000) {
down_smp_factor >>= 16;
tmp += 16;
}
if (down_smp_factor & 0xff00) {
down_smp_factor >>= 8;
tmp += 8;
}
if (down_smp_factor & 0xf0) {
down_smp_factor >>= 4;
tmp += 4;
}
if (down_smp_factor & 0xc) {
down_smp_factor >>= 2;
tmp += 2;
}
if (down_smp_factor & 0x2) {
down_smp_factor >>= 1;
tmp += 1;
}
const int log2_factor = tmp;
const int rounding_offset = (1 << (log2_factor - 1)); const int rounding_offset = (1 << (log2_factor - 1));
uint16_t src_idx = 0; uint16_t src_idx = 0;
@ -667,31 +671,7 @@ void kvz_mip_pred_upsampling_1D(kvz_pixel* const dst, const kvz_pixel* const src
const uint16_t boundary_step, const uint16_t boundary_step,
const uint16_t ups_factor) const uint16_t ups_factor)
{ {
// Calculate floor log2. MIP_TODO: find a better / faster solution const int log2_factor = kvz_math_floor_log2(ups_factor);
uint16_t upsample_factor = ups_factor;
int tmp = 0;
if (upsample_factor & 0xffff0000) {
upsample_factor >>= 16;
tmp += 16;
}
if (upsample_factor & 0xff00) {
upsample_factor >>= 8;
tmp += 8;
}
if (upsample_factor & 0xf0) {
upsample_factor >>= 4;
tmp += 4;
}
if (upsample_factor & 0xc) {
upsample_factor >>= 2;
tmp += 2;
}
if (upsample_factor & 0x2) {
upsample_factor >>= 1;
tmp += 1;
}
const int log2_factor = tmp;
assert(ups_factor >= 2 && "Upsampling factor must be at least 2."); assert(ups_factor >= 2 && "Upsampling factor must be at least 2.");
const int rounding_offset = 1 << (log2_factor - 1); const int rounding_offset = 1 << (log2_factor - 1);

View file

@ -150,6 +150,8 @@ void kvz_predict_cclm(
cclm_parameters_t* cclm_params cclm_parameters_t* cclm_params
); );
int kvz_get_mip_flag_context(int x, int y, int width, int height, lcu_t* const lcu, cu_array_t* const cu_a);
void kvz_mip_predict( void kvz_mip_predict(
encoder_state_t const * const state, encoder_state_t const * const state,
kvz_intra_references * refs, kvz_intra_references * refs,

View file

@ -726,11 +726,11 @@ static int8_t search_intra_rdo(encoder_state_t * const state,
{ {
const int tr_depth = CLIP(1, MAX_PU_DEPTH, depth + state->encoder_control->cfg.tr_depth_intra); const int tr_depth = CLIP(1, MAX_PU_DEPTH, depth + state->encoder_control->cfg.tr_depth_intra);
const int width = LCU_WIDTH >> depth; const int width = LCU_WIDTH >> depth;
const int height = width; // TODO: proper height for non-square blocks
kvz_pixel orig_block[LCU_WIDTH * LCU_WIDTH + 1]; kvz_pixel orig_block[LCU_WIDTH * LCU_WIDTH + 1];
// TODO: height for non-square blocks kvz_pixels_blit(orig, orig_block, width, height, origstride, width);
kvz_pixels_blit(orig, orig_block, width, width, origstride, width);
// Check that the predicted modes are in the RDO mode list // Check that the predicted modes are in the RDO mode list
if (modes_to_check < 67) { if (modes_to_check < 67) {
@ -756,46 +756,19 @@ static int8_t search_intra_rdo(encoder_state_t * const state,
} }
// MIP_TODO: implement this inside the standard intra for loop. Code duplication is bad. // MIP_TODO: implement this inside the standard intra for loop. Code duplication is bad.
// MIP_TODO: deriving mip flag context id could be done in it's own function since the exact same code is used in encode_coding_tree.c
// MIP search // MIP search
const int transp_off = num_mip_modes >> 1; const int transp_off = num_mip_modes >> 1;
for (uint8_t mip_mode = 0; mip_mode < num_mip_modes; ++mip_mode) { for (uint8_t mip_mode = 0; mip_mode < num_mip_modes; ++mip_mode) {
// Derive mip flag context id // Derive mip flag context id
uint8_t ctx_id = 0; uint8_t ctx_id = kvz_get_mip_flag_context(x_px, y_px, width, height, lcu, NULL);
const videoframe_t* const frame = state->tile->frame;
const vector2d_t lcu_px = { SUB_SCU(x_px), SUB_SCU(y_px) };
cu_info_t* cur_cu;
cur_cu = LCU_GET_CU_AT_PX(lcu, lcu_px.x, lcu_px.y);
const int cu_width = width;
const int cu_height = cu_width; // TODO: height for non-square blocks
const int pu_x = PU_GET_X(cur_cu->part_size, cu_width, x_px, 0);
const int pu_y = PU_GET_Y(cur_cu->part_size, cu_width, y_px, 0);
const cu_info_t* left_pu = NULL;
const cu_info_t* above_pu = NULL;
if (pu_x > 0) {
assert(pu_x >> 2 > 0);
left_pu = kvz_cu_array_at_const(frame->cu_array, pu_x - 1, pu_y + cu_width - 1);
}
if (left_pu != NULL) {
ctx_id = left_pu->intra.mip_flag ? 1 : 0;
}
// Don't take the above PU across the LCU boundary.
if (pu_y % LCU_WIDTH > 0 && pu_y > 0) {
assert(pu_y >> 2 > 0);
above_pu = kvz_cu_array_at_const(frame->cu_array, pu_x + cu_width - 1, pu_y - 1);
}
if (above_pu != NULL) {
ctx_id += above_pu->intra.mip_flag ? 1 : 0;
}
ctx_id = (cu_width > 2 * cu_height || cu_height > 2 * cu_width) ? 3 : ctx_id;
int rdo_bitcost = kvz_luma_mode_bits(state, mip_modes[mip_mode], intra_preds, 0, num_mip_modes, ctx_id); int rdo_bitcost = kvz_luma_mode_bits(state, mip_modes[mip_mode], intra_preds, 0, num_mip_modes, ctx_id);
mip_costs[mip_mode] = rdo_bitcost * (int)(state->lambda + 0.5); // MIP_TODO: check if this is also correct in the case when MIP is used. mip_costs[mip_mode] = rdo_bitcost * (int)(state->lambda + 0.5); // MIP_TODO: check if this is also correct in the case when MIP is used.
const bool is_transposed = (mip_modes[mip_mode] >= transp_off ? true : false); const bool is_transposed = (mip_modes[mip_mode] >= transp_off ? true : false);
// There can be 32 MIP modes, but only mode numbers [0, 15] are ever written to bitstream. // There can be 32 MIP modes, but only mode numbers [0, 15] are ever written to bitstream.
// Modes [16, 31] are indicated with the separate transpose flag. // Half of the modes [16, 31] are indicated with the separate transpose flag.
// Number of possible modes is less for larger blocks.
int8_t pred_mode = (is_transposed ? mip_modes[mip_mode] - transp_off : mip_modes[mip_mode]); int8_t pred_mode = (is_transposed ? mip_modes[mip_mode] - transp_off : mip_modes[mip_mode]);
// Perform transform split search and save mode RD cost for the best one. // Perform transform split search and save mode RD cost for the best one.
@ -1244,11 +1217,11 @@ void kvz_search_cu_intra(encoder_state_t * const state,
num_mip_modes = 32; num_mip_modes = 32;
} }
else if (width == 4 || height == 4 || (width == 8 && height == 8)) { else if (width == 4 || height == 4 || (width == 8 && height == 8)) {
// Mip size_id = 0. Num modes = 16 // Mip size_id = 1. Num modes = 16
num_mip_modes = 16; num_mip_modes = 16;
} }
else { else {
// Mip size_id = 0. Num modes = 12 // Mip size_id = 2. Num modes = 12
num_mip_modes = 12; num_mip_modes = 12;
} }
} }
@ -1358,6 +1331,12 @@ void kvz_search_cu_intra(encoder_state_t * const state,
} }
} }
if (tmp_mip_flag) {
// Transform best mode index to proper form.
// Max mode index is half of max number of modes - 1 (i. e. for size id 2, max mode id is 5)
tmp_best_mode = (tmp_mip_transp ? tmp_best_mode - (num_mip_modes >> 1) : tmp_best_mode);
}
*mode_out = tmp_best_mode; *mode_out = tmp_best_mode;
*trafo_out = tmp_best_trafo; *trafo_out = tmp_best_trafo;
*cost_out = tmp_best_cost; *cost_out = tmp_best_cost;