diff --git a/src/cabac.c b/src/cabac.c index ed480e17..36931277 100644 --- a/src/cabac.c +++ b/src/cabac.c @@ -491,26 +491,28 @@ void kvz_cabac_write_coeff_remain_encry(struct encoder_state_t * const state, ca /** * \brief */ -void kvz_cabac_write_unary_max_symbol(cabac_data_t * const data, cabac_ctx_t * const ctx, uint32_t symbol, const int32_t offset, const uint32_t max_symbol) +void kvz_cabac_write_unary_max_symbol(cabac_data_t * const data, + cabac_ctx_t * const ctx, + uint32_t symbol, + const int32_t offset, + const uint32_t max_symbol, + double* bits_out) { int8_t code_last = max_symbol > symbol; assert(symbol <= max_symbol); if (!max_symbol) return; - - data->cur_ctx = &ctx[0]; - CABAC_BIN(data, symbol, "ums"); + + CABAC_FBITS_UPDATE(data, &ctx[0], symbol, *bits_out, "ums"); if (!symbol) return; while (--symbol) { - data->cur_ctx = &ctx[offset]; - CABAC_BIN(data, 1, "ums"); + CABAC_FBITS_UPDATE(data, &ctx[offset], 1, *bits_out, "ums"); } if (code_last) { - data->cur_ctx = &ctx[offset]; - CABAC_BIN(data, 0, "ums"); + CABAC_FBITS_UPDATE(data, &ctx[offset], 0,*bits_out, "ums"); } } diff --git a/src/cabac.h b/src/cabac.h index 62d59d9e..f9190045 100644 --- a/src/cabac.h +++ b/src/cabac.h @@ -125,8 +125,8 @@ void kvz_cabac_write_coeff_remain_encry(struct encoder_state_t * const state, ca uint32_t kvz_cabac_write_ep_ex_golomb(struct encoder_state_t * const state, cabac_data_t *data, uint32_t symbol, uint32_t count); void kvz_cabac_write_unary_max_symbol(cabac_data_t *data, cabac_ctx_t *ctx, - uint32_t symbol, int32_t offset, - uint32_t max_symbol); + uint32_t symbol, int32_t offset, + uint32_t max_symbol, double* bits_out); void kvz_cabac_write_unary_max_symbol_ep(cabac_data_t *data, unsigned int symbol, unsigned int max_symbol); extern const float kvz_f_entropy_bits[128]; diff --git a/src/encode_coding_tree.c b/src/encode_coding_tree.c index a847640e..b25494f4 100644 --- a/src/encode_coding_tree.c +++ b/src/encode_coding_tree.c @@ -290,7 +290,7 @@ static void encode_transform_coeff(encoder_state_t * const state, // cu_qp_delta_abs prefix cabac->cur_ctx = &cabac->ctx.cu_qp_delta_abs[0]; - kvz_cabac_write_unary_max_symbol(cabac, cabac->ctx.cu_qp_delta_abs, MIN(qp_delta_abs, 5), 1, 5); + kvz_cabac_write_unary_max_symbol(cabac, cabac->ctx.cu_qp_delta_abs, MIN(qp_delta_abs, 5), 1, 5, NULL); if (qp_delta_abs >= 5) { // cu_qp_delta_abs suffix @@ -412,7 +412,7 @@ void kvz_encode_inter_prediction_unit(encoder_state_t * const state, cabac->ctx.mvp_idx_model, CU_GET_MV_CAND(cur_cu, ref_list_idx), 1, - AMVP_MAX_NUM_CANDS - 1); + AMVP_MAX_NUM_CANDS - 1, bits_out); } // for ref_list } // if !merge @@ -467,7 +467,7 @@ static INLINE uint8_t intra_mode_encryption(encoder_state_t * const state, static void encode_intra_coding_unit(encoder_state_t * const state, cabac_data_t * const cabac, const cu_info_t * const cur_cu, - int x, int y, int depth, double* bits_out) + int x, int y, int depth, lcu_t* lcu, double* bits_out) { const videoframe_t * const frame = state->tile->frame; uint8_t intra_pred_mode_actual[4]; @@ -506,19 +506,19 @@ static void encode_intra_coding_unit(encoder_state_t * const state, for (int j = 0; j < num_pred_units; ++j) { const int pu_x = PU_GET_X(cur_cu->part_size, cu_width, x, j); const int pu_y = PU_GET_Y(cur_cu->part_size, cu_width, y, j); - const cu_info_t *cur_pu = kvz_cu_array_at_const(frame->cu_array, pu_x, pu_y); + const cu_info_t *cur_pu = lcu ? LCU_GET_CU_AT_PX(lcu, SUB_SCU(pu_x), SUB_SCU(pu_y)) : kvz_cu_array_at_const(frame->cu_array, pu_x, pu_y); const cu_info_t *left_pu = NULL; const cu_info_t *above_pu = NULL; if (pu_x > 0) { assert(pu_x >> 2 > 0); - left_pu = kvz_cu_array_at_const(frame->cu_array, pu_x - 1, pu_y); + left_pu = lcu ? LCU_GET_CU_AT_PX(lcu, SUB_SCU(pu_x -1), SUB_SCU(pu_y)) : kvz_cu_array_at_const(frame->cu_array, pu_x - 1, pu_y); } // Don't take the above PU across the LCU boundary. if (pu_y % LCU_WIDTH > 0 && pu_y > 0) { assert(pu_y >> 2 > 0); - above_pu = kvz_cu_array_at_const(frame->cu_array, pu_x, pu_y - 1); + above_pu = lcu ? LCU_GET_CU_AT_PX(lcu, SUB_SCU(pu_x), SUB_SCU(pu_y - 1)) : kvz_cu_array_at_const(frame->cu_array, pu_x, pu_y - 1); } if (do_crypto) { @@ -893,7 +893,7 @@ void kvz_encode_coding_tree(encoder_state_t * const state, } } } else if (cur_cu->type == CU_INTRA) { - encode_intra_coding_unit(state, cabac, cur_cu, x, y, depth, NULL); + encode_intra_coding_unit(state, cabac, cur_cu, x, y, depth, NULL, NULL); } #if ENABLE_PCM @@ -952,11 +952,11 @@ end: } -void kvz_mock_encode_coding_unit( +double kvz_mock_encode_coding_unit( encoder_state_t* const state, cabac_data_t* cabac, int x, int y, int depth, - lcu_t* lcu) { + lcu_t* lcu, cu_info_t* cur_cu) { double bits = 0; const encoder_control_t* const ctrl = state->encoder_control; @@ -964,9 +964,7 @@ void kvz_mock_encode_coding_unit( int y_local = SUB_SCU(y); const int cu_width = LCU_WIDTH >> depth; - const int half_cu = cu_width >> 1; - - const cu_info_t* cur_cu = LCU_GET_CU_AT_PX(lcu, x_local, y_local); + const cu_info_t* left_cu = NULL, *above_cu = NULL; if (x) { left_cu = LCU_GET_CU_AT_PX(lcu, x_local - 1, y_local); @@ -1037,7 +1035,7 @@ void kvz_mock_encode_coding_unit( } } } - return; + return bits; } } // Prediction mode @@ -1072,8 +1070,9 @@ void kvz_mock_encode_coding_unit( } } else if (cur_cu->type == CU_INTRA) { - encode_intra_coding_unit(state, cabac, cur_cu, x, y, depth, NULL); + encode_intra_coding_unit(state, cabac, cur_cu, x, y, depth, lcu, &bits); } + return bits; } diff --git a/src/encode_coding_tree.h b/src/encode_coding_tree.h index b8e29358..42a1a981 100644 --- a/src/encode_coding_tree.h +++ b/src/encode_coding_tree.h @@ -52,11 +52,11 @@ void kvz_encode_mvd(encoder_state_t * const state, int32_t mvd_ver, double* bits_out); -void kvz_mock_encode_coding_unit( +double kvz_mock_encode_coding_unit( encoder_state_t* const state, cabac_data_t* cabac, int x, int y, int depth, - lcu_t* lcu); + lcu_t* lcu, cu_info_t* cur_cu); void kvz_encode_inter_prediction_unit(encoder_state_t* const state, cabac_data_t* const cabac, diff --git a/src/rdo.c b/src/rdo.c index 5b6c3b49..fc0b2198 100644 --- a/src/rdo.c +++ b/src/rdo.c @@ -1081,8 +1081,8 @@ double kvz_calc_mvd_cost_cabac(const encoder_state_t * state, x - mv_cand[1][0], y - mv_cand[1][1], }; - uint32_t cand1_cost = kvz_get_mvd_coding_cost_cabac(state, cabac, mvd1.x, mvd1.y); - uint32_t cand2_cost = kvz_get_mvd_coding_cost_cabac(state, cabac, mvd2.x, mvd2.y); + double cand1_cost = kvz_get_mvd_coding_cost_cabac(state, cabac, mvd1.x, mvd1.y); + double cand2_cost = kvz_get_mvd_coding_cost_cabac(state, cabac, mvd2.x, mvd2.y); // Select candidate 1 if it has lower cost if (cand2_cost < cand1_cost) { @@ -1161,11 +1161,12 @@ double kvz_calc_mvd_cost_cabac(const encoder_state_t * state, // Signal which candidate MV to use kvz_cabac_write_unary_max_symbol( - cabac, - cabac->ctx.mvp_idx_model, - cur_mv_cand, - 1, - AMVP_MAX_NUM_CANDS - 1); + cabac, + cabac->ctx.mvp_idx_model, + cur_mv_cand, + 1, + AMVP_MAX_NUM_CANDS - 1, + NULL); } } } diff --git a/src/search.c b/src/search.c index ad24b501..1fc36566 100644 --- a/src/search.c +++ b/src/search.c @@ -37,6 +37,7 @@ #include "cabac.h" #include "encoder.h" +#include "encode_coding_tree.h" #include "imagelist.h" #include "inter.h" #include "intra.h" @@ -743,61 +744,19 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth, cabac_data_t* cabac = &state->search_cabac; cabac->update = 1; - if(depth < MAX_DEPTH) { - uint8_t split_model = get_ctx_cu_split_model(lcu, x, y, depth); - cabac_ctx_t* ctx = &(cabac->ctx.split_flag_model[split_model]); - CABAC_FBITS_UPDATE(cabac, ctx, 0, bits, "no_split_search"); + if(cur_cu->type != CU_INTRA || cur_cu->part_size == SIZE_2Nx2N) { + bits += kvz_mock_encode_coding_unit( + state, + cabac, + x, y, depth, + lcu, + cur_cu); } - else if(depth == MAX_DEPTH && cur_cu->type == CU_INTRA) { - // Add cost of intra part_size. - cabac_ctx_t* ctx = &(cabac->ctx.part_size_model[0]); - CABAC_FBITS_UPDATE(cabac, ctx, 0, bits, "no_split_search"); + else { + // Intra 4×4 PUs } - - double mode_bits = 0; - if (state->frame->slicetype != KVZ_SLICE_I) { - int ctx_skip = 0; - if (x > 0) { - ctx_skip += LCU_GET_CU_AT_PX(lcu, x_local - 1, y_local)->skipped; - } - if (y > 0) { - ctx_skip += LCU_GET_CU_AT_PX(lcu, x_local, y_local - 1)->skipped; - } - CABAC_FBITS_UPDATE(cabac, &(cabac->ctx.cu_skip_flag_model[ctx_skip]), cur_cu->skipped, mode_bits, "skip_flag"); - if (cur_cu->skipped) { - int16_t num_cand = state->encoder_control->cfg.max_merge; - if (num_cand > 1) { - for (int ui = 0; ui < num_cand - 1; ui++) { - int32_t symbol = (ui != cur_cu->merge_idx); - if (ui == 0) { - CABAC_FBITS_UPDATE(cabac, &(cabac->ctx.cu_merge_idx_ext_model), symbol, mode_bits, "MergeIndex"); - } - else { - CABAC_BIN_EP(cabac, symbol, "MergeIndex"); - mode_bits += 1; - } - if (symbol == 0) { - break; - } - } - } - } - - } - if (cur_cu->type == CU_INTRA) { - if(state->frame->slicetype != KVZ_SLICE_I) { - cabac_ctx_t* ctx = &(cabac->ctx.cu_pred_mode_model); - CABAC_FBITS_UPDATE(cabac, ctx, 1, mode_bits, "pred_mode_flag"); - } - mode_bits += calc_mode_bits(state, lcu, cur_cu, x, y); - } - else if (!cur_cu->skipped) { - cabac_ctx_t* ctx = &(cabac->ctx.cu_pred_mode_model); - CABAC_FBITS_UPDATE(cabac, ctx, 0, mode_bits, "pred_mode_flag"); - mode_bits += inter_bitcost; - } - bits += mode_bits; - cost = mode_bits * state->lambda; + + cost = bits * state->lambda; cost += kvz_cu_rd_cost_luma(state, x_local, y_local, depth, cur_cu, lcu, &bits); if (state->encoder_control->chroma_format != KVZ_CSP_400) {