diff --git a/src/encode_coding_tree.c b/src/encode_coding_tree.c index c398fa97..77b43f5f 100644 --- a/src/encode_coding_tree.c +++ b/src/encode_coding_tree.c @@ -873,28 +873,13 @@ static void encode_intra_coding_unit(encoder_state_t * const state, } if (mip_flag) { - assert(mip_mode >= 0 && mip_mode < 16 && "MIP mode must be between [0, 15]"); + assert(mip_mode >= 0 && mip_mode < num_mip_modes && "Invalid MIP mode."); } if (cur_cu->type == CU_INTRA && !cur_cu->bdpcmMode && enable_mip) { - // Derive mip flag context id - uint8_t ctx_id = 0; const int cu_width = LCU_WIDTH >> depth; const int cu_height = cu_width; // TODO: height for non-square blocks - const int pu_x = PU_GET_X(cur_cu->part_size, cu_width, x, 0); - const int pu_y = PU_GET_Y(cur_cu->part_size, cu_height, y, 0); - - if (pu_x > 0) { - assert(pu_x >> 2 > 0); - // Get mip flag from left PU - ctx_id = kvz_cu_array_at_const(frame->cu_array, pu_x - 1, pu_y)->intra.mip_flag ? 1 : 0; - } - if (pu_y % LCU_WIDTH > 0 && pu_y > 0) { - assert(pu_y >> 2 > 0); - // Get mip flag from above PU - ctx_id += kvz_cu_array_at_const(frame->cu_array, pu_x, pu_y - 1)->intra.mip_flag ? 1 : 0; - } - ctx_id = (cu_width > 2 * cu_height || cu_height > 2 * cu_width) ? 3 : ctx_id; + uint8_t ctx_id = kvz_get_mip_flag_context(x, y, cu_width, cu_height, NULL, frame->cu_array); // Write MIP flag cabac->cur_ctx = &(cabac->ctx.mip_flag[ctx_id]); diff --git a/src/intra.c b/src/intra.c index fa7f27d6..2d3d0939 100644 --- a/src/intra.c +++ b/src/intra.c @@ -546,37 +546,41 @@ void kvz_predict_cclm( } +int kvz_get_mip_flag_context(int x, int y, int width, int height, lcu_t* const lcu, cu_array_t* const cu_a) { + assert(!(lcu && cu_a)); + int context = 0; + + if (lcu) { + int x_local = SUB_SCU(x); + int y_local = SUB_SCU(y); + if (x) { + context += LCU_GET_CU_AT_PX(lcu, x_local - 1, y_local)->intra.mip_flag; + } + if (y) { + context += LCU_GET_CU_AT_PX(lcu, x_local, y_local - 1)->intra.mip_flag; + } + context = (width > 2 * height || height > 2 * width) ? 3 : context; + } + else { + if (x > 0) { + context += kvz_cu_array_at_const(cu_a, x - 1, y)->intra.mip_flag; + } + if (y > 0) { + context += kvz_cu_array_at_const(cu_a, x, y - 1)->intra.mip_flag; + } + context = (width > 2 * height || height > 2 * width) ? 3 : context; + } + return context; +} + + void kvz_mip_boundary_downsampling_1D(kvz_pixel* reduced_dst, const kvz_pixel* const ref_src, int src_len, int dst_len) { if (dst_len < src_len) { // Create reduced boundary by downsampling uint16_t down_smp_factor = src_len / dst_len; - - // Calculate floor log2. MIP_TODO: find a better / faster solution - int tmp = 0; - if (down_smp_factor & 0xffff0000) { - down_smp_factor >>= 16; - tmp += 16; - } - if (down_smp_factor & 0xff00) { - down_smp_factor >>= 8; - tmp += 8; - } - if (down_smp_factor & 0xf0) { - down_smp_factor >>= 4; - tmp += 4; - } - if (down_smp_factor & 0xc) { - down_smp_factor >>= 2; - tmp += 2; - } - if (down_smp_factor & 0x2) { - down_smp_factor >>= 1; - tmp += 1; - } - - const int log2_factor = tmp; + const int log2_factor = kvz_math_floor_log2(down_smp_factor); const int rounding_offset = (1 << (log2_factor - 1)); uint16_t src_idx = 0; @@ -667,31 +671,7 @@ void kvz_mip_pred_upsampling_1D(kvz_pixel* const dst, const kvz_pixel* const src const uint16_t boundary_step, const uint16_t ups_factor) { - // Calculate floor log2. MIP_TODO: find a better / faster solution - uint16_t upsample_factor = ups_factor; - int tmp = 0; - if (upsample_factor & 0xffff0000) { - upsample_factor >>= 16; - tmp += 16; - } - if (upsample_factor & 0xff00) { - upsample_factor >>= 8; - tmp += 8; - } - if (upsample_factor & 0xf0) { - upsample_factor >>= 4; - tmp += 4; - } - if (upsample_factor & 0xc) { - upsample_factor >>= 2; - tmp += 2; - } - if (upsample_factor & 0x2) { - upsample_factor >>= 1; - tmp += 1; - } - - const int log2_factor = tmp; + const int log2_factor = kvz_math_floor_log2(ups_factor); assert(ups_factor >= 2 && "Upsampling factor must be at least 2."); const int rounding_offset = 1 << (log2_factor - 1); diff --git a/src/intra.h b/src/intra.h index 436e20bf..7bd27e1f 100644 --- a/src/intra.h +++ b/src/intra.h @@ -150,6 +150,8 @@ void kvz_predict_cclm( cclm_parameters_t* cclm_params ); +int kvz_get_mip_flag_context(int x, int y, int width, int height, lcu_t* const lcu, cu_array_t* const cu_a); + void kvz_mip_predict( encoder_state_t const * const state, kvz_intra_references * refs, diff --git a/src/search_intra.c b/src/search_intra.c index 6800bfef..efeda0e5 100644 --- a/src/search_intra.c +++ b/src/search_intra.c @@ -726,11 +726,11 @@ static int8_t search_intra_rdo(encoder_state_t * const state, { const int tr_depth = CLIP(1, MAX_PU_DEPTH, depth + state->encoder_control->cfg.tr_depth_intra); const int width = LCU_WIDTH >> depth; + const int height = width; // TODO: proper height for non-square blocks kvz_pixel orig_block[LCU_WIDTH * LCU_WIDTH + 1]; - // TODO: height for non-square blocks - kvz_pixels_blit(orig, orig_block, width, width, origstride, width); + kvz_pixels_blit(orig, orig_block, width, height, origstride, width); // Check that the predicted modes are in the RDO mode list if (modes_to_check < 67) { @@ -756,46 +756,19 @@ static int8_t search_intra_rdo(encoder_state_t * const state, } // MIP_TODO: implement this inside the standard intra for loop. Code duplication is bad. - // MIP_TODO: deriving mip flag context id could be done in it's own function since the exact same code is used in encode_coding_tree.c // MIP search const int transp_off = num_mip_modes >> 1; for (uint8_t mip_mode = 0; mip_mode < num_mip_modes; ++mip_mode) { // Derive mip flag context id - uint8_t ctx_id = 0; - const videoframe_t* const frame = state->tile->frame; - const vector2d_t lcu_px = { SUB_SCU(x_px), SUB_SCU(y_px) }; - cu_info_t* cur_cu; - cur_cu = LCU_GET_CU_AT_PX(lcu, lcu_px.x, lcu_px.y); - const int cu_width = width; - const int cu_height = cu_width; // TODO: height for non-square blocks - const int pu_x = PU_GET_X(cur_cu->part_size, cu_width, x_px, 0); - const int pu_y = PU_GET_Y(cur_cu->part_size, cu_width, y_px, 0); - const cu_info_t* left_pu = NULL; - const cu_info_t* above_pu = NULL; - - if (pu_x > 0) { - assert(pu_x >> 2 > 0); - left_pu = kvz_cu_array_at_const(frame->cu_array, pu_x - 1, pu_y + cu_width - 1); - } - if (left_pu != NULL) { - ctx_id = left_pu->intra.mip_flag ? 1 : 0; - } - // Don't take the above PU across the LCU boundary. - if (pu_y % LCU_WIDTH > 0 && pu_y > 0) { - assert(pu_y >> 2 > 0); - above_pu = kvz_cu_array_at_const(frame->cu_array, pu_x + cu_width - 1, pu_y - 1); - } - if (above_pu != NULL) { - ctx_id += above_pu->intra.mip_flag ? 1 : 0; - } - ctx_id = (cu_width > 2 * cu_height || cu_height > 2 * cu_width) ? 3 : ctx_id; + uint8_t ctx_id = kvz_get_mip_flag_context(x_px, y_px, width, height, lcu, NULL); int rdo_bitcost = kvz_luma_mode_bits(state, mip_modes[mip_mode], intra_preds, 0, num_mip_modes, ctx_id); mip_costs[mip_mode] = rdo_bitcost * (int)(state->lambda + 0.5); // MIP_TODO: check if this is also correct in the case when MIP is used. const bool is_transposed = (mip_modes[mip_mode] >= transp_off ? true : false); // There can be 32 MIP modes, but only mode numbers [0, 15] are ever written to bitstream. - // Modes [16, 31] are indicated with the separate transpose flag. + // Half of the modes [16, 31] are indicated with the separate transpose flag. + // Number of possible modes is less for larger blocks. int8_t pred_mode = (is_transposed ? mip_modes[mip_mode] - transp_off : mip_modes[mip_mode]); // Perform transform split search and save mode RD cost for the best one. @@ -1244,11 +1217,11 @@ void kvz_search_cu_intra(encoder_state_t * const state, num_mip_modes = 32; } else if (width == 4 || height == 4 || (width == 8 && height == 8)) { - // Mip size_id = 0. Num modes = 16 + // Mip size_id = 1. Num modes = 16 num_mip_modes = 16; } else { - // Mip size_id = 0. Num modes = 12 + // Mip size_id = 2. Num modes = 12 num_mip_modes = 12; } } @@ -1358,6 +1331,12 @@ void kvz_search_cu_intra(encoder_state_t * const state, } } + if (tmp_mip_flag) { + // Transform best mode index to proper form. + // Max mode index is half of max number of modes - 1 (i. e. for size id 2, max mode id is 5) + tmp_best_mode = (tmp_mip_transp ? tmp_best_mode - (num_mip_modes >> 1) : tmp_best_mode); + } + *mode_out = tmp_best_mode; *trafo_out = tmp_best_trafo; *cost_out = tmp_best_cost;