Use correct context for calculating coeff costs for transform skip

This commit is contained in:
Joose Sainio 2021-06-07 13:06:03 +03:00
parent 4594bf0ca8
commit cfffd7166c
8 changed files with 43 additions and 23 deletions

View file

@ -89,10 +89,10 @@ static void encode_mts_idx(encoder_state_t * const state,
}
}
static void encode_ts_residual(encoder_state_t* const state,
void kvz_encode_ts_residual(encoder_state_t* const state,
cabac_data_t* const cabac,
const coeff_t* coeff,
uint8_t width,
uint32_t width,
uint8_t type,
int8_t scan_mode) {
//const encoder_control_t * const encoder = state->encoder_control;
@ -397,7 +397,7 @@ static void encode_transform_unit(encoder_state_t * const state,
}
if(cur_pu->tr_idx == MTS_SKIP) {
encode_ts_residual(state, cabac, coeff_y, width, 0, scan_idx);
kvz_encode_ts_residual(state, cabac, coeff_y, width, 0, scan_idx);
}
else {
kvz_encode_coeff_nxn(state,

View file

@ -33,6 +33,13 @@ void kvz_encode_coding_tree(encoder_state_t * const state,
uint16_t y_ctb,
uint8_t depth);
void kvz_encode_ts_residual(encoder_state_t* const state,
cabac_data_t* const cabac,
const coeff_t* coeff,
uint32_t width,
uint8_t type,
int8_t scan_mode);
void kvz_encode_mvd(encoder_state_t * const state,
cabac_data_t *cabac,
int32_t mvd_hor,

View file

@ -236,7 +236,8 @@ static INLINE uint32_t get_coeff_cabac_cost(
const coeff_t *coeff,
int32_t width,
int32_t type,
int8_t scan_mode)
int8_t scan_mode,
int8_t tr_skip)
{
// Make sure there are coeffs present
bool found = false;
@ -261,14 +262,24 @@ static INLINE uint32_t get_coeff_cabac_cost(
// Execute the coding function.
// It is safe to drop the const modifier since state won't be modified
// when cabac.only_count is set.
kvz_encode_coeff_nxn((encoder_state_t*) state,
&cabac_copy,
coeff,
width,
type,
scan_mode,
NULL,
false);
if(!tr_skip) {
kvz_encode_coeff_nxn((encoder_state_t*) state,
&cabac_copy,
coeff,
width,
type,
scan_mode,
NULL,
false);
}
else {
kvz_encode_ts_residual(state,
&cabac_copy,
coeff,
width,
type,
scan_mode);
}
return (23 - cabac_copy.bits_left) + (cabac_copy.num_buffered_bytes << 3);
}
@ -313,13 +324,14 @@ uint32_t kvz_get_coeff_cost(const encoder_state_t * const state,
const coeff_t *coeff,
int32_t width,
int32_t type,
int8_t scan_mode)
int8_t scan_mode,
int8_t tr_skip)
{
uint8_t save_cccs = state->encoder_control->cfg.fastrd_sampling_on;
uint8_t check_accuracy = state->encoder_control->cfg.fastrd_accuracy_check_on;
if (state->qp < state->encoder_control->cfg.fast_residual_cost_limit &&
state->qp < MAX_FAST_COEFF_COST_QP) {
state->qp < MAX_FAST_COEFF_COST_QP && !tr_skip) {
// TODO: do we need to assert(0) out of the fast-estimation branch if we
// are to save block costs, or should we just warn about it somewhere
// earlier (configuration validation I guess)?
@ -330,13 +342,13 @@ uint32_t kvz_get_coeff_cost(const encoder_state_t * const state,
uint64_t weights = kvz_fast_coeff_get_weights(state);
uint32_t fast_cost = kvz_fast_coeff_cost(coeff, width, weights);
if (check_accuracy) {
uint32_t ccc = get_coeff_cabac_cost(state, coeff, width, type, scan_mode);
uint32_t ccc = get_coeff_cabac_cost(state, coeff, width, type, scan_mode, tr_skip);
save_accuracy(state->qp, ccc, fast_cost);
}
return fast_cost;
}
} else {
uint32_t ccc = get_coeff_cabac_cost(state, coeff, width, type, scan_mode);
uint32_t ccc = get_coeff_cabac_cost(state, coeff, width, type, scan_mode, tr_skip);
if (save_cccs) {
save_ccc(state->qp, coeff, width * width, ccc);
}

View file

@ -51,7 +51,8 @@ uint32_t kvz_get_coeff_cost(const encoder_state_t * const state,
const coeff_t *coeff,
int32_t width,
int32_t type,
int8_t scan_mode);
int8_t scan_mode,
int8_t tr_skip);
int32_t kvz_get_ic_rate(encoder_state_t *state, uint32_t abs_level, uint16_t ctx_num_gt1, uint16_t ctx_num_gt2, uint16_t ctx_num_par,
uint16_t abs_go_rice, uint32_t reg_bins, int8_t type, int use_limited_prefix_length);

View file

@ -299,7 +299,7 @@ double kvz_cu_rd_cost_luma(const encoder_state_t *const state,
int8_t luma_scan_mode = kvz_get_scan_order(pred_cu->type, pred_cu->intra.mode, depth);
const coeff_t *coeffs = &lcu->coeff.y[xy_to_zorder(LCU_WIDTH, x_px, y_px)];
coeff_bits += kvz_get_coeff_cost(state, coeffs, width, 0, luma_scan_mode);
coeff_bits += kvz_get_coeff_cost(state, coeffs, width, 0, luma_scan_mode, pred_cu->tr_idx == MTS_SKIP);
}
double bits = tr_tree_bits + coeff_bits;
@ -371,8 +371,8 @@ double kvz_cu_rd_cost_chroma(const encoder_state_t *const state,
int8_t scan_order = kvz_get_scan_order(pred_cu->type, pred_cu->intra.mode_chroma, depth);
const int index = xy_to_zorder(LCU_WIDTH_C, lcu_px.x, lcu_px.y);
coeff_bits += kvz_get_coeff_cost(state, &lcu->coeff.u[index], width, 2, scan_order);
coeff_bits += kvz_get_coeff_cost(state, &lcu->coeff.v[index], width, 2, scan_order);
coeff_bits += kvz_get_coeff_cost(state, &lcu->coeff.u[index], width, 2, scan_order, 0);
coeff_bits += kvz_get_coeff_cost(state, &lcu->coeff.v[index], width, 2, scan_order, 0);
}
double bits = tr_tree_bits + coeff_bits;

View file

@ -693,7 +693,7 @@ int kvz_quantize_residual_avx2(encoder_state_t *const state,
// Quantize coeffs. (coeff -> coeff_out)
if (state->encoder_control->cfg.rdoq_enable &&
(width > 4 || !state->encoder_control->cfg.rdoq_skip))
(width > 4 || !state->encoder_control->cfg.rdoq_skip) && !use_trskip)
{
int8_t tr_depth = cur_cu->tr_depth - cur_cu->depth;
tr_depth += (cur_cu->part_size == SIZE_NxN ? 1 : 0);

View file

@ -239,7 +239,7 @@ int kvz_quantize_residual_generic(encoder_state_t *const state,
// Quantize coeffs. (coeff -> coeff_out)
if (state->encoder_control->cfg.rdoq_enable &&
(width > 4 || !state->encoder_control->cfg.rdoq_skip))
(width > 4 || !state->encoder_control->cfg.rdoq_skip) && !use_trskip)
{
int8_t tr_depth = cur_cu->tr_depth - cur_cu->depth;
tr_depth += (cur_cu->part_size == SIZE_NxN ? 1 : 0);

View file

@ -272,7 +272,7 @@ int kvz_quantize_residual_trskip(
1, in_stride, 4,
ref_in, pred_in, skip.rec, skip.coeff, false, lmcs_chroma_adj);
skip.cost = kvz_pixels_calc_ssd(ref_in, skip.rec, in_stride, 4, 4);
skip.cost += kvz_get_coeff_cost(state, skip.coeff, 4, 0, scan_order) * bit_cost;
skip.cost += kvz_get_coeff_cost(state, skip.coeff, 4, 0, scan_order, 1) * bit_cost;
/* if (noskip.cost <= skip.cost) {
*trskip_out = 0;