accurate bit cost calculation when using transform skip

This commit is contained in:
Joose Sainio 2021-12-20 09:36:23 +02:00
parent a038ccc19a
commit 243e45f07e
4 changed files with 153 additions and 27 deletions

View file

@ -106,8 +106,8 @@ void kvz_cabac_start(cabac_data_t * const data)
void kvz_cabac_encode_bin(cabac_data_t * const data, const uint32_t bin_value)
{
uint32_t lps;
if (!(data)->only_count) bits_written += CTX_ENTROPY_FBITS((data)->cur_ctx, (bin_value));
lps = kvz_g_auc_lpst_table[CTX_STATE(data->cur_ctx)][(data->range >> 6) & 3];
data->range -= lps;
@ -577,6 +577,6 @@ uint32_t kvz_cabac_write_ep_ex_golomb(encoder_state_t * const state,
bins = ( (bins >> (num_bins >>1) ) << (num_bins >>1) ) | state->crypto_prev_pos;
}
}
kvz_cabac_encode_bins_ep(data, bins, num_bins);
CABAC_BINS_EP(data, bins, num_bins, "ep_ex_golomb");
return num_bins;
}

View file

@ -156,7 +156,6 @@ extern double bits_written;
#ifdef VERBOSE
#define CABAC_BIN(data, value, name) { \
uint32_t prev_state = (data)->cur_ctx->uc_state; \
if(!(data)->only_count) bits_written += CTX_ENTROPY_FBITS((data)->cur_ctx, (value));\
kvz_cabac_encode_bin((data), (value)); \
if(!(data)->only_count) printf("%s = %u, state = %u -> %u MPS = %u bits = %f\n", \
(name), (uint32_t)(value), prev_state, (data)->cur_ctx->uc_state, CTX_MPS((data)->cur_ctx), bits_written); }

View file

@ -685,6 +685,7 @@ static void encoder_state_worker_encode_lcu(void * opaque)
const uint64_t existing_bits = kvz_bitstream_tell(&state->stream);
//Encode SAO
state->cabac.update = 1;
if (encoder->cfg.sao_type) {
encode_sao(state, lcu->position.x, lcu->position.y, &frame->sao_luma[lcu->position.y * frame->width_in_lcu + lcu->position.x], &frame->sao_chroma[lcu->position.y * frame->width_in_lcu + lcu->position.x]);
}
@ -737,6 +738,7 @@ static void encoder_state_worker_encode_lcu(void * opaque)
kvz_crypto_delete(&state->crypto_hdl);
}
}
state->cabac.update = 0;
pthread_mutex_lock(&state->frame->rc_lock);
const uint32_t bits = kvz_bitstream_tell(&state->stream) - existing_bits;

View file

@ -299,30 +299,34 @@ double kvz_cu_rd_cost_luma(const encoder_state_t *const state,
return sum + tr_tree_bits * state->lambda;
}
if (cabac->update && tr_cu->tr_depth == tr_cu->depth) {
// Because these need to be coded before the luma cbf they also need to be counted
// before the cabac state changes. However, since this branch is only executed when
// calculating the last RD cost it is not problem to include the chroma cbf costs in
// luma, because the chroma cost is calculated right after the luma cost.
// However, if we have different tr_depth, the bits cannot be written in correct
// order anyways so do not touch the chroma cbf here.
if (state->encoder_control->chroma_format != KVZ_CSP_400) {
cabac_ctx_t* cr_ctx = &(cabac->ctx.qt_cbf_model_chroma[depth - tr_cu->depth]);
cabac->cur_ctx = cr_ctx;
int u_is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_U);
int v_is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_V);
CABAC_FBITS_UPDATE(cabac, cr_ctx, u_is_set, tr_tree_bits, "cbf_cb_search");
CABAC_FBITS_UPDATE(cabac, cr_ctx, v_is_set, tr_tree_bits, "cbf_cb_search");
}
}
// Add transform_tree cbf_luma bit cost.
const int is_tr_split = tr_cu->tr_depth - tr_cu->depth;
if (pred_cu->type == CU_INTRA ||
tr_depth > 0 ||
is_tr_split ||
cbf_is_set(tr_cu->cbf, depth, COLOR_U) ||
cbf_is_set(tr_cu->cbf, depth, COLOR_V))
{
cabac_ctx_t *ctx = &(cabac->ctx.qt_cbf_model_luma[!tr_depth]);
cabac_ctx_t *ctx = &(cabac->ctx.qt_cbf_model_luma[!is_tr_split]);
int is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_Y);
if (cabac->update && tr_cu->tr_depth == tr_cu->depth) {
// Because these need to be coded before the luma cbf they also need to be counted
// before the cabac state changes. However, since this branch is only executed when
// calculating the last RD cost it is not problem to include the chroma cbf costs in
// luma, because the chroma cost is calculated right after the luma cost.
// However, if we have different tr_depth, the bits cannot be written in correct
// order anyways so do not touch the chroma cbf here.
if (state->encoder_control->chroma_format != KVZ_CSP_400) {
cabac_ctx_t* cr_ctx = &(cabac->ctx.qt_cbf_model_chroma[tr_depth]);
cabac->cur_ctx = cr_ctx;
int u_is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_U);
int v_is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_V);
CABAC_FBITS_UPDATE(cabac, cr_ctx, u_is_set, tr_tree_bits, "cbf_cb_search");
CABAC_FBITS_UPDATE(cabac, cr_ctx, v_is_set, tr_tree_bits, "cbf_cb_search");
}
}
CABAC_FBITS_UPDATE(cabac, ctx, is_set, tr_tree_bits, "cbf_y_search");
*bit_cost += tr_tree_bits;
}
@ -390,7 +394,7 @@ double kvz_cu_rd_cost_chroma(const encoder_state_t *const state,
if (tr_cu->tr_depth > depth) {
int offset = LCU_WIDTH >> (depth + 1);
int sum = 0;
double sum = 0;
sum += kvz_cu_rd_cost_chroma(state, x_px, y_px, depth + 1, pred_cu, lcu, bit_cost);
sum += kvz_cu_rd_cost_chroma(state, x_px + offset, y_px, depth + 1, pred_cu, lcu, bit_cost);
@ -426,6 +430,126 @@ double kvz_cu_rd_cost_chroma(const encoder_state_t *const state,
return (double)ssd * CHROMA_MULT + bits * state->lambda;
}
static double cu_rd_cost_tr_split_accurate(const encoder_state_t* const state,
const int x_px, const int y_px, const int depth,
const cu_info_t* const pred_cu,
lcu_t* const lcu,
double* bit_cost) {
const int width = LCU_WIDTH >> depth;
// cur_cu is used for TU parameters.
cu_info_t* const tr_cu = LCU_GET_CU_AT_PX(lcu, x_px, y_px);
double coeff_bits = 0;
double tr_tree_bits = 0;
// Check that lcu is not in
assert(x_px >= 0 && x_px < LCU_WIDTH);
assert(y_px >= 0 && y_px < LCU_WIDTH);
const uint8_t tr_depth = tr_cu->tr_depth - depth;
const int cb_flag_u = cbf_is_set(tr_cu->cbf, depth, COLOR_U);
const int cb_flag_v = cbf_is_set(tr_cu->cbf, depth, COLOR_V);
cabac_data_t* cabac = (cabac_data_t*)&state->search_cabac;
// Add transform_tree split_transform_flag bit cost.
bool intra_split_flag = pred_cu->type == CU_INTRA && pred_cu->part_size == SIZE_NxN && depth == 3;
int max_tr_depth;
if (pred_cu->type == CU_INTRA) {
max_tr_depth = state->encoder_control->cfg.tr_depth_intra + intra_split_flag;
}
else {
max_tr_depth = state->encoder_control->tr_depth_inter;
}
if (width <= TR_MAX_WIDTH
&& width > TR_MIN_WIDTH
&& !intra_split_flag
&& MIN(tr_cu->tr_depth, depth) - tr_cu->depth < max_tr_depth)
{
cabac_ctx_t* ctx = &(cabac->ctx.trans_subdiv_model[5 - (6 - depth)]);
CABAC_FBITS_UPDATE(cabac, ctx, tr_depth > 0, tr_tree_bits, "tr_split_search");
}
if(state->encoder_control->chroma_format != KVZ_CSP_400) {
if(tr_cu->depth == depth || cbf_is_set(tr_cu->cbf, depth - 1, COLOR_U)) {
CABAC_FBITS_UPDATE(cabac, &(cabac->ctx.qt_cbf_model_chroma[depth - tr_cu->depth]), cb_flag_u, tr_tree_bits, "cbf_cb");
}
if(tr_cu->depth == depth || cbf_is_set(tr_cu->cbf, depth - 1, COLOR_V)) {
CABAC_FBITS_UPDATE(cabac, &(cabac->ctx.qt_cbf_model_chroma[depth - tr_cu->depth]), cb_flag_v, tr_tree_bits, "cbf_cr");
}
}
if (tr_depth > 0) {
int offset = LCU_WIDTH >> (depth + 1);
double sum = 0;
*bit_cost += tr_tree_bits;
sum += cu_rd_cost_tr_split_accurate(state, x_px, y_px, depth + 1, pred_cu, lcu, bit_cost);
sum += cu_rd_cost_tr_split_accurate(state, x_px + offset, y_px, depth + 1, pred_cu, lcu, bit_cost);
sum += cu_rd_cost_tr_split_accurate(state, x_px, y_px + offset, depth + 1, pred_cu, lcu, bit_cost);
sum += cu_rd_cost_tr_split_accurate(state, x_px + offset, y_px + offset, depth + 1, pred_cu, lcu, bit_cost);
return sum + tr_tree_bits * state->lambda;
}
const int cb_flag_y = cbf_is_set(tr_cu->cbf, depth, COLOR_Y) ;
// Add transform_tree cbf_luma bit cost.
const int is_tr_split = depth - tr_cu->depth;
if (pred_cu->type == CU_INTRA ||
is_tr_split ||
cb_flag_u ||
cb_flag_v)
{
cabac_ctx_t* ctx = &(cabac->ctx.qt_cbf_model_luma[!is_tr_split]);
CABAC_FBITS_UPDATE(cabac, ctx, cb_flag_y, tr_tree_bits, "cbf_y_search");
}
*bit_cost += tr_tree_bits;
// SSD between reconstruction and original
unsigned luma_ssd = 0;
if (!state->encoder_control->cfg.lossless) {
int index = y_px * LCU_WIDTH + x_px;
luma_ssd = kvz_pixels_calc_ssd(&lcu->ref.y[index], &lcu->rec.y[index],
LCU_WIDTH, LCU_WIDTH,
width);
}
{
int8_t luma_scan_mode = kvz_get_scan_order(pred_cu->type, pred_cu->intra.mode, depth);
const coeff_t* coeffs = &lcu->coeff.y[xy_to_zorder(LCU_WIDTH, x_px, y_px)];
coeff_bits += kvz_get_coeff_cost(state, coeffs, width, 0, luma_scan_mode);
}
unsigned chroma_ssd = 0;
if(state->encoder_control->chroma_format != KVZ_CSP_400 && x_px % 8 == 0 && y_px % 8 == 0) {
const vector2d_t lcu_px = { x_px / 2, y_px / 2 };
const int chroma_width = (depth <= MAX_DEPTH) ? LCU_WIDTH >> (depth + 1) : LCU_WIDTH >> depth;
if (!state->encoder_control->cfg.lossless) {
int index = lcu_px.y * LCU_WIDTH_C + lcu_px.x;
unsigned ssd_u = kvz_pixels_calc_ssd(&lcu->ref.u[index], &lcu->rec.u[index],
LCU_WIDTH_C, LCU_WIDTH_C,
chroma_width);
unsigned ssd_v = kvz_pixels_calc_ssd(&lcu->ref.v[index], &lcu->rec.v[index],
LCU_WIDTH_C, LCU_WIDTH_C,
chroma_width);
chroma_ssd = ssd_u + ssd_v;
}
{
int8_t scan_order = kvz_get_scan_order(pred_cu->type, pred_cu->intra.mode_chroma, depth);
const unsigned index = xy_to_zorder(LCU_WIDTH_C, lcu_px.x, lcu_px.y);
coeff_bits += kvz_get_coeff_cost(state, &lcu->coeff.u[index], chroma_width, 2, scan_order);
coeff_bits += kvz_get_coeff_cost(state, &lcu->coeff.v[index], chroma_width, 2, scan_order);
}
}
*bit_cost += coeff_bits;
double bits = tr_tree_bits + coeff_bits;
return luma_ssd * LUMA_MULT + chroma_ssd * CHROMA_MULT + bits * state->lambda;
}
// Return estimate of bits used to code prediction mode of cur_cu.
static double calc_mode_bits(const encoder_state_t *state,
@ -763,10 +887,10 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth,
cost = bits * state->lambda;
cost += kvz_cu_rd_cost_luma(state, x_local, y_local, depth, cur_cu, lcu, &bits);
if (state->encoder_control->chroma_format != KVZ_CSP_400) {
cost += kvz_cu_rd_cost_chroma(state, x_local, y_local, depth, cur_cu, lcu, & bits);
}
cost += cu_rd_cost_tr_split_accurate(state, x_local, y_local, depth, cur_cu, lcu, &bits);
//if (state->encoder_control->chroma_format != KVZ_CSP_400) {
// cost += kvz_cu_rd_cost_chroma(state, x_local, y_local, depth, cur_cu, lcu, & bits);
//}
FILE_BITS(bits, x, y, depth, "final rd bits");
@ -826,6 +950,7 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth,
cabac_ctx_t *ctx = &(state->search_cabac.ctx.part_size_model[0]);
CABAC_FBITS_UPDATE(&state->search_cabac, ctx, 0, split_bits, "split_search");
}
FILE_BITS(split_bits, x, y, depth, "split");
state->search_cabac.update = 0;
split_cost += split_bits * state->lambda;