diff --git a/src/cabac.c b/src/cabac.c index 7f5b92c2..5842edbe 100644 --- a/src/cabac.c +++ b/src/cabac.c @@ -97,6 +97,7 @@ void kvz_cabac_start(cabac_data_t * const data) data->num_buffered_bytes = 0; data->buffered_byte = 0xff; data->only_count = 0; // By default, write bits out + data->update = 0; } /** diff --git a/src/search.c b/src/search.c index 2cb34608..a0534bf4 100644 --- a/src/search.c +++ b/src/search.c @@ -265,17 +265,27 @@ double kvz_cu_rd_cost_luma(encoder_state_t *const state, const uint8_t tr_depth = tr_cu->tr_depth - depth; + cabac_data_t* cabac = &state->search_cabac; + // Add transform_tree split_transform_flag bit cost. bool intra_split_flag = pred_cu->type == CU_INTRA && pred_cu->part_size == SIZE_NxN && depth == 3; + int max_tr_depth; + if (tr_cu->type == CU_INTRA) { + max_tr_depth = state->encoder_control->cfg.tr_depth_intra + intra_split_flag; + } + else { + max_tr_depth = state->encoder_control->tr_depth_inter; + } if (width <= TR_MAX_WIDTH && width > TR_MIN_WIDTH - && !intra_split_flag) + && !intra_split_flag + && tr_depth < max_tr_depth) { - const cabac_ctx_t *ctx = &(state->search_cabac.ctx.trans_subdiv_model[5 - (6 - depth)]); + const cabac_ctx_t *ctx = &(cabac->ctx.trans_subdiv_model[5 - (6 - depth)]); tr_tree_bits += CTX_ENTROPY_FBITS(ctx, tr_depth > 0); - if (state->search_cabac.update) { - state->search_cabac.cur_ctx = ctx; - CABAC_BIN(&state->search_cabac, tr_depth > 0, "tr_split_search"); + if (cabac->update) { + cabac->cur_ctx = ctx; + CABAC_BIN(cabac, tr_depth > 0, "tr_split_search"); } *bit_cost += tr_tree_bits; } @@ -298,14 +308,28 @@ double kvz_cu_rd_cost_luma(encoder_state_t *const state, cbf_is_set(tr_cu->cbf, depth, COLOR_U) || cbf_is_set(tr_cu->cbf, depth, COLOR_V)) { - const cabac_ctx_t *ctx = &(state->search_cabac.ctx.qt_cbf_model_luma[!tr_depth]); + const cabac_ctx_t *ctx = &(cabac->ctx.qt_cbf_model_luma[!tr_depth]); int is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_Y); - tr_tree_bits += CTX_ENTROPY_FBITS(ctx, is_set); - if (state->search_cabac.update) { + if (cabac->update) { + // Because these need to be coded before the luma cbf they also need to be counted + // before the cabac state changes. However, since this branch is only executed when + // calculating the last RD cost it is not problem to include the chroma cbf costs in + // luma, because the chroma cost is calculated right after the luma cost. + if (state->encoder_control->chroma_format != KVZ_CSP_400) { + const cabac_ctx_t* cr_ctx = &(state->search_cabac.ctx.qt_cbf_model_chroma[tr_depth]); + cabac->cur_ctx = cr_ctx; + int u_is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_U); + int v_is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_V); + tr_tree_bits += CTX_ENTROPY_FBITS(cr_ctx, u_is_set); + CABAC_BIN(cabac, u_is_set, "cbf_cb_search"); + tr_tree_bits += CTX_ENTROPY_FBITS(cr_ctx, v_is_set); + CABAC_BIN(cabac, v_is_set, "cbf_cr_search"); + } + tr_tree_bits += CTX_ENTROPY_FBITS(ctx, is_set); + *bit_cost += tr_tree_bits; state->search_cabac.cur_ctx = ctx; CABAC_BIN(&state->search_cabac, is_set, "luma_cbf_search"); } - *bit_cost += CTX_ENTROPY_FBITS(ctx, is_set); } @@ -353,7 +377,8 @@ double kvz_cu_rd_cost_chroma(const encoder_state_t *const state, return 0; } - if (depth < MAX_PU_DEPTH) { + // See luma for why the second condition + if (depth < MAX_PU_DEPTH && !state->search_cabac.update) { const int tr_depth = depth - pred_cu->depth; const cabac_ctx_t *ctx = &(state->search_cabac.ctx.qt_cbf_model_chroma[tr_depth]); if (tr_depth == 0 || cbf_is_set(pred_cu->cbf, depth - 1, COLOR_U)) { @@ -712,12 +737,21 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth, double bits = 0; state->search_cabac.update = 1; - uint8_t split_model = get_ctx_cu_split_model(lcu, x, y, depth); - cabac_ctx_t* ctx = &(state->search_cabac.ctx.split_flag_model[split_model]); - state->search_cabac.cur_ctx = ctx; - // TODO: intra 4x4 PUs use different method - bits += CTX_ENTROPY_FBITS(ctx, 0); - CABAC_BIN(&state->search_cabac, 0, "no_split_search"); + if(depth < MAX_DEPTH) { + uint8_t split_model = get_ctx_cu_split_model(lcu, x, y, depth); + cabac_ctx_t* ctx = &(state->search_cabac.ctx.split_flag_model[split_model]); + state->search_cabac.cur_ctx = ctx; + bits += CTX_ENTROPY_FBITS(ctx, 0); + CABAC_BIN(&state->search_cabac, 0, "no_split_search"); + } + else if(depth == MAX_DEPTH && cur_cu->type == CU_INTRA) { + // Add cost of intra part_size. + const cabac_ctx_t* ctx = &(state->search_cabac.ctx.part_size_model[0]); + bits += CTX_ENTROPY_FBITS(ctx, 1); // NxN + state->search_cabac.cur_ctx = ctx; + FILE_BITS(CTX_ENTROPY_FBITS(ctx, 1), x, y, depth, "split"); + CABAC_BIN(&state->search_cabac, 1, "split_search"); + } double mode_bits; if (cur_cu->type == CU_INTRA) { @@ -776,6 +810,7 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth, cabac_data_t post_seach_cabac; memcpy(&post_seach_cabac, &state->search_cabac, sizeof(post_seach_cabac)); memcpy(&state->search_cabac, &pre_search_cabac, sizeof(post_seach_cabac)); + state->search_cabac.update = 1; if (depth < MAX_DEPTH) { // Add cost of cu_split_flag. @@ -792,9 +827,10 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth, const cabac_ctx_t *ctx = &(state->search_cabac.ctx.part_size_model[0]); split_cost += CTX_ENTROPY_FBITS(ctx, 0) * state->lambda; // NxN state->search_cabac.cur_ctx = ctx; - FILE_BITS(CTX_ENTROPY_FBITS(ctx, 1), x, y, depth, "split"); - CABAC_BIN(&state->search_cabac, 1, "split_search"); + FILE_BITS(CTX_ENTROPY_FBITS(ctx, 0), x, y, depth, "split"); + CABAC_BIN(&state->search_cabac, 0, "split_search"); } + state->search_cabac.update = 0; // If skip mode was selected for the block, skip further search. // Skip mode means there's no coefficients in the block, so splitting @@ -1023,6 +1059,8 @@ static void copy_lcu_to_cu_data(const encoder_state_t * const state, int x_px, i void kvz_search_lcu(encoder_state_t * const state, const int x, const int y, const yuv_t * const hor_buf, const yuv_t * const ver_buf) { if (bit_cost_file == NULL) bit_cost_file = fopen("bits_file.txt", "w"); + memcpy(&state->search_cabac, &state->cabac, sizeof(cabac_data_t)); + state->search_cabac.only_count = 1; assert(x % LCU_WIDTH == 0); assert(y % LCU_WIDTH == 0);