count all non-tr-depth related bits correctly

This commit is contained in:
Joose Sainio 2021-12-07 08:13:08 +02:00
parent 53264bc764
commit 9ed8d0a7d9
2 changed files with 57 additions and 18 deletions

View file

@ -97,6 +97,7 @@ void kvz_cabac_start(cabac_data_t * const data)
data->num_buffered_bytes = 0; data->num_buffered_bytes = 0;
data->buffered_byte = 0xff; data->buffered_byte = 0xff;
data->only_count = 0; // By default, write bits out data->only_count = 0; // By default, write bits out
data->update = 0;
} }
/** /**

View file

@ -265,17 +265,27 @@ double kvz_cu_rd_cost_luma(encoder_state_t *const state,
const uint8_t tr_depth = tr_cu->tr_depth - depth; const uint8_t tr_depth = tr_cu->tr_depth - depth;
cabac_data_t* cabac = &state->search_cabac;
// Add transform_tree split_transform_flag bit cost. // Add transform_tree split_transform_flag bit cost.
bool intra_split_flag = pred_cu->type == CU_INTRA && pred_cu->part_size == SIZE_NxN && depth == 3; bool intra_split_flag = pred_cu->type == CU_INTRA && pred_cu->part_size == SIZE_NxN && depth == 3;
int max_tr_depth;
if (tr_cu->type == CU_INTRA) {
max_tr_depth = state->encoder_control->cfg.tr_depth_intra + intra_split_flag;
}
else {
max_tr_depth = state->encoder_control->tr_depth_inter;
}
if (width <= TR_MAX_WIDTH if (width <= TR_MAX_WIDTH
&& width > TR_MIN_WIDTH && width > TR_MIN_WIDTH
&& !intra_split_flag) && !intra_split_flag
&& tr_depth < max_tr_depth)
{ {
const cabac_ctx_t *ctx = &(state->search_cabac.ctx.trans_subdiv_model[5 - (6 - depth)]); const cabac_ctx_t *ctx = &(cabac->ctx.trans_subdiv_model[5 - (6 - depth)]);
tr_tree_bits += CTX_ENTROPY_FBITS(ctx, tr_depth > 0); tr_tree_bits += CTX_ENTROPY_FBITS(ctx, tr_depth > 0);
if (state->search_cabac.update) { if (cabac->update) {
state->search_cabac.cur_ctx = ctx; cabac->cur_ctx = ctx;
CABAC_BIN(&state->search_cabac, tr_depth > 0, "tr_split_search"); CABAC_BIN(cabac, tr_depth > 0, "tr_split_search");
} }
*bit_cost += tr_tree_bits; *bit_cost += tr_tree_bits;
} }
@ -298,14 +308,28 @@ double kvz_cu_rd_cost_luma(encoder_state_t *const state,
cbf_is_set(tr_cu->cbf, depth, COLOR_U) || cbf_is_set(tr_cu->cbf, depth, COLOR_U) ||
cbf_is_set(tr_cu->cbf, depth, COLOR_V)) cbf_is_set(tr_cu->cbf, depth, COLOR_V))
{ {
const cabac_ctx_t *ctx = &(state->search_cabac.ctx.qt_cbf_model_luma[!tr_depth]); const cabac_ctx_t *ctx = &(cabac->ctx.qt_cbf_model_luma[!tr_depth]);
int is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_Y); int is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_Y);
tr_tree_bits += CTX_ENTROPY_FBITS(ctx, is_set); if (cabac->update) {
if (state->search_cabac.update) { // Because these need to be coded before the luma cbf they also need to be counted
// before the cabac state changes. However, since this branch is only executed when
// calculating the last RD cost it is not problem to include the chroma cbf costs in
// luma, because the chroma cost is calculated right after the luma cost.
if (state->encoder_control->chroma_format != KVZ_CSP_400) {
const cabac_ctx_t* cr_ctx = &(state->search_cabac.ctx.qt_cbf_model_chroma[tr_depth]);
cabac->cur_ctx = cr_ctx;
int u_is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_U);
int v_is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_V);
tr_tree_bits += CTX_ENTROPY_FBITS(cr_ctx, u_is_set);
CABAC_BIN(cabac, u_is_set, "cbf_cb_search");
tr_tree_bits += CTX_ENTROPY_FBITS(cr_ctx, v_is_set);
CABAC_BIN(cabac, v_is_set, "cbf_cr_search");
}
tr_tree_bits += CTX_ENTROPY_FBITS(ctx, is_set);
*bit_cost += tr_tree_bits;
state->search_cabac.cur_ctx = ctx; state->search_cabac.cur_ctx = ctx;
CABAC_BIN(&state->search_cabac, is_set, "luma_cbf_search"); CABAC_BIN(&state->search_cabac, is_set, "luma_cbf_search");
} }
*bit_cost += CTX_ENTROPY_FBITS(ctx, is_set);
} }
@ -353,7 +377,8 @@ double kvz_cu_rd_cost_chroma(const encoder_state_t *const state,
return 0; return 0;
} }
if (depth < MAX_PU_DEPTH) { // See luma for why the second condition
if (depth < MAX_PU_DEPTH && !state->search_cabac.update) {
const int tr_depth = depth - pred_cu->depth; const int tr_depth = depth - pred_cu->depth;
const cabac_ctx_t *ctx = &(state->search_cabac.ctx.qt_cbf_model_chroma[tr_depth]); const cabac_ctx_t *ctx = &(state->search_cabac.ctx.qt_cbf_model_chroma[tr_depth]);
if (tr_depth == 0 || cbf_is_set(pred_cu->cbf, depth - 1, COLOR_U)) { if (tr_depth == 0 || cbf_is_set(pred_cu->cbf, depth - 1, COLOR_U)) {
@ -712,12 +737,21 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth,
double bits = 0; double bits = 0;
state->search_cabac.update = 1; state->search_cabac.update = 1;
uint8_t split_model = get_ctx_cu_split_model(lcu, x, y, depth); if(depth < MAX_DEPTH) {
cabac_ctx_t* ctx = &(state->search_cabac.ctx.split_flag_model[split_model]); uint8_t split_model = get_ctx_cu_split_model(lcu, x, y, depth);
state->search_cabac.cur_ctx = ctx; cabac_ctx_t* ctx = &(state->search_cabac.ctx.split_flag_model[split_model]);
// TODO: intra 4x4 PUs use different method state->search_cabac.cur_ctx = ctx;
bits += CTX_ENTROPY_FBITS(ctx, 0); bits += CTX_ENTROPY_FBITS(ctx, 0);
CABAC_BIN(&state->search_cabac, 0, "no_split_search"); CABAC_BIN(&state->search_cabac, 0, "no_split_search");
}
else if(depth == MAX_DEPTH && cur_cu->type == CU_INTRA) {
// Add cost of intra part_size.
const cabac_ctx_t* ctx = &(state->search_cabac.ctx.part_size_model[0]);
bits += CTX_ENTROPY_FBITS(ctx, 1); // NxN
state->search_cabac.cur_ctx = ctx;
FILE_BITS(CTX_ENTROPY_FBITS(ctx, 1), x, y, depth, "split");
CABAC_BIN(&state->search_cabac, 1, "split_search");
}
double mode_bits; double mode_bits;
if (cur_cu->type == CU_INTRA) { if (cur_cu->type == CU_INTRA) {
@ -776,6 +810,7 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth,
cabac_data_t post_seach_cabac; cabac_data_t post_seach_cabac;
memcpy(&post_seach_cabac, &state->search_cabac, sizeof(post_seach_cabac)); memcpy(&post_seach_cabac, &state->search_cabac, sizeof(post_seach_cabac));
memcpy(&state->search_cabac, &pre_search_cabac, sizeof(post_seach_cabac)); memcpy(&state->search_cabac, &pre_search_cabac, sizeof(post_seach_cabac));
state->search_cabac.update = 1;
if (depth < MAX_DEPTH) { if (depth < MAX_DEPTH) {
// Add cost of cu_split_flag. // Add cost of cu_split_flag.
@ -792,9 +827,10 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth,
const cabac_ctx_t *ctx = &(state->search_cabac.ctx.part_size_model[0]); const cabac_ctx_t *ctx = &(state->search_cabac.ctx.part_size_model[0]);
split_cost += CTX_ENTROPY_FBITS(ctx, 0) * state->lambda; // NxN split_cost += CTX_ENTROPY_FBITS(ctx, 0) * state->lambda; // NxN
state->search_cabac.cur_ctx = ctx; state->search_cabac.cur_ctx = ctx;
FILE_BITS(CTX_ENTROPY_FBITS(ctx, 1), x, y, depth, "split"); FILE_BITS(CTX_ENTROPY_FBITS(ctx, 0), x, y, depth, "split");
CABAC_BIN(&state->search_cabac, 1, "split_search"); CABAC_BIN(&state->search_cabac, 0, "split_search");
} }
state->search_cabac.update = 0;
// If skip mode was selected for the block, skip further search. // If skip mode was selected for the block, skip further search.
// Skip mode means there's no coefficients in the block, so splitting // Skip mode means there's no coefficients in the block, so splitting
@ -1023,6 +1059,8 @@ static void copy_lcu_to_cu_data(const encoder_state_t * const state, int x_px, i
void kvz_search_lcu(encoder_state_t * const state, const int x, const int y, const yuv_t * const hor_buf, const yuv_t * const ver_buf) void kvz_search_lcu(encoder_state_t * const state, const int x, const int y, const yuv_t * const hor_buf, const yuv_t * const ver_buf)
{ {
if (bit_cost_file == NULL) bit_cost_file = fopen("bits_file.txt", "w"); if (bit_cost_file == NULL) bit_cost_file = fopen("bits_file.txt", "w");
memcpy(&state->search_cabac, &state->cabac, sizeof(cabac_data_t));
state->search_cabac.only_count = 1;
assert(x % LCU_WIDTH == 0); assert(x % LCU_WIDTH == 0);
assert(y % LCU_WIDTH == 0); assert(y % LCU_WIDTH == 0);