[intra] count mts and tr_skip bits

This commit is contained in:
Joose Sainio 2022-05-13 15:04:03 +03:00
parent 804bf3afcb
commit 418c65cbf1
4 changed files with 120 additions and 104 deletions

View file

@ -47,9 +47,9 @@
#include "tables.h"
#include "videoframe.h"
static bool is_mts_allowed(encoder_state_t * const state, cu_info_t *const pred_cu)
bool uvg_is_mts_allowed(encoder_state_t * const state, cu_info_t *const pred_cu)
{
uint32_t ts_max_size = 1 << 2; //cu.cs->sps->getLog2MaxTransformSkipBlockSize();
uint32_t ts_max_size = 1 << state->encoder_control->cfg.trskip_max_size;
const uint32_t max_size = 32; // CU::isIntra(cu) ? MTS_INTRA_MAX_CU_SIZE : MTS_INTER_MAX_CU_SIZE;
const uint32_t cu_width = LCU_WIDTH >> pred_cu->depth;
const uint32_t cu_height = LCU_WIDTH >> pred_cu->depth;
@ -61,6 +61,7 @@ static bool is_mts_allowed(encoder_state_t * const state, cu_info_t *const pred_
//mts_allowed &= !cu.ispMode; // ISP_TODO: Uncomment this when ISP is implemented.
//mts_allowed &= !cu.sbtInfo;
mts_allowed &= !(pred_cu->bdpcmMode && cu_width <= ts_max_size && cu_height <= ts_max_size);
mts_allowed &= pred_cu->tr_idx != MTS_SKIP && !pred_cu->violates_mts_coeff_constraint && pred_cu->mts_last_scan_pos;
return mts_allowed;
}
@ -71,7 +72,7 @@ static void encode_mts_idx(encoder_state_t * const state,
//TransformUnit &tu = *cu.firstTU;
int mts_idx = pred_cu->tr_idx;
if (is_mts_allowed(state, (cu_info_t* const )pred_cu) && mts_idx != MTS_SKIP
if (uvg_is_mts_allowed(state, (cu_info_t* const )pred_cu) && mts_idx != MTS_SKIP
&& !pred_cu->violates_mts_coeff_constraint
&& pred_cu->mts_last_scan_pos
&& pred_cu->lfnst_idx == 0
@ -718,7 +719,7 @@ static void encode_transform_coeff(encoder_state_t * const state,
if ((cur_cu->type == CU_INTRA || tr_depth > 0 || cb_flag_u || cb_flag_v) && !only_chroma) {
cabac->cur_ctx = &(cabac->ctx.qt_cbf_model_luma[0]);
CABAC_BIN(cabac, cb_flag_y, "cbf_luma");
// printf("%hu %hu %d %d\n", cabac->ctx.qt_cbf_model_luma[0].state[0], cabac->ctx.qt_cbf_model_luma[0].state[1], x, y);
printf("%hu %hu %d %d\n", cabac->ctx.qt_cbf_model_luma[0].state[0], cabac->ctx.qt_cbf_model_luma[0].state[1], x, y);
}
if (cb_flag_y | cb_flag_u | cb_flag_v) {

View file

@ -40,7 +40,9 @@
#include "encoderstate.h"
#include "global.h"
void uvg_encode_coding_tree(encoder_state_t * const state,
bool uvg_is_mts_allowed(encoder_state_t* const state, cu_info_t* const pred_cu);
void kvz_encode_coding_tree(encoder_state_t * const state,
uint16_t x_ctb,
uint16_t y_ctb,
uint8_t depth,

View file

@ -324,38 +324,23 @@ double uvg_cu_rd_cost_luma(const encoder_state_t *const state,
return sum + tr_tree_bits * state->lambda;
}
if (cabac->update && tr_cu->tr_depth == tr_cu->depth && !skip_residual_coding) {
// Because these need to be coded before the luma cbf they also need to be counted
// before the cabac state changes. However, since this branch is only executed when
// calculating the last RD cost it is not problem to include the chroma cbf costs in
// luma, because the chroma cost is calculated right after the luma cost.
// However, if we have different tr_depth, the bits cannot be written in correct
// order anyways so do not touch the chroma cbf here.
if (state->encoder_control->chroma_format != UVG_CSP_400) {
cabac_ctx_t* cr_ctx = &(cabac->ctx.qt_cbf_model_cb[0]);
cabac->cur_ctx = cr_ctx;
int u_is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_U);
int v_is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_V);
CABAC_FBITS_UPDATE(cabac, cr_ctx, u_is_set, tr_tree_bits, "cbf_cb_search");
cr_ctx = &(cabac->ctx.qt_cbf_model_cr[u_is_set]);
CABAC_FBITS_UPDATE(cabac, cr_ctx, v_is_set, tr_tree_bits, "cbf_cb_search");
}
}
// Add transform_tree cbf_luma bit cost.
const int is_tr_split = tr_cu->tr_depth - tr_cu->depth;
int is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_Y);
if (pred_cu->type == CU_INTRA ||
is_tr_split ||
cbf_is_set(tr_cu->cbf, depth, COLOR_U) ||
cbf_is_set(tr_cu->cbf, depth, COLOR_V))
{
cabac_ctx_t *ctx = &(cabac->ctx.qt_cbf_model_luma[0]);
int is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_Y);
CABAC_FBITS_UPDATE(cabac, ctx, is_set, tr_tree_bits, "cbf_y_search");
}
if (is_set && state->encoder_control->cfg.trskip_enable && width <= (1 << state->encoder_control->cfg.trskip_max_size)) {
CABAC_FBITS_UPDATE(cabac, &cabac->ctx.transform_skip_model_luma, tr_cu->tr_idx == MTS_SKIP, tr_tree_bits, "transform_skip_flag");
}
// SSD between reconstruction and original
int ssd = 0;
if (!state->encoder_control->cfg.lossless) {
@ -555,7 +540,10 @@ static double cu_rd_cost_tr_split_accurate(const encoder_state_t* const state,
width);
}
{
if(cb_flag_y){
if (state->encoder_control->cfg.trskip_enable && width <= (1 << state->encoder_control->cfg.trskip_max_size)) {
CABAC_FBITS_UPDATE(cabac, &cabac->ctx.transform_skip_model_luma, tr_cu->tr_idx == MTS_SKIP, tr_tree_bits, "transform_skip_flag");
}
int8_t luma_scan_mode = uvg_get_scan_order(pred_cu->type, pred_cu->intra.mode, depth);
const coeff_t* coeffs = &lcu->coeff.y[xy_to_zorder(LCU_WIDTH, x_px, y_px)];
@ -596,6 +584,20 @@ static double cu_rd_cost_tr_split_accurate(const encoder_state_t* const state,
coeff_bits += uvg_get_coeff_cost(state, &lcu->coeff.joint_uv[index], width, 2, scan_order, 0);
}
}
if (kvz_is_mts_allowed(state, tr_cu)) {
bool symbol = tr_cu->tr_idx != 0;
int ctx_idx = 0;
CABAC_FBITS_UPDATE(cabac, &state->search_cabac.ctx.mts_idx_model[ctx_idx], symbol, tr_tree_bits, "mts_idx");
ctx_idx++;
for (int i = 0; i < 3 && symbol; i++, ctx_idx++)
{
symbol = tr_cu->tr_idx > i + MTS_DST7_DST7 ? 1 : 0;
CABAC_FBITS_UPDATE(cabac, &state->search_cabac.ctx.mts_idx_model[ctx_idx], symbol, tr_tree_bits, "mts_idx");
}
}
double bits = tr_tree_bits + coeff_bits;
return luma_ssd * UVG_LUMA_MULT + chroma_ssd * UVG_CHROMA_MULT + bits * state->lambda;

View file

@ -381,7 +381,7 @@ static double search_intra_trdepth(
}
const int mts_start = trafo;
//TODO: height
if(state->encoder_control->cfg.trskip_enable && width <= (1 << state->encoder_control->cfg.trskip_max_size) /*&& height == 4*/) {
if (state->encoder_control->cfg.trskip_enable && width <= (1 << state->encoder_control->cfg.trskip_max_size) /*&& height == 4*/) {
num_transforms = MAX(num_transforms, 2);
}
pred_cu->intra.mode_chroma = -1;
@ -424,14 +424,15 @@ static double search_intra_trdepth(
lcu);
// TODO: Not sure if this should be 0 or 1 but at least seems to work with 1
derive_mts_constraints(pred_cu, lcu, depth, lcu_px);
if (pred_cu->tr_idx > 1)
{
derive_mts_constraints(pred_cu, lcu, depth, lcu_px);
if (pred_cu->violates_mts_coeff_constraint || !pred_cu->mts_last_scan_pos)
{
continue;
}
}
if (pred_cu->lfnst_idx > 0) {
// Temp constraints. Updating the actual pred_cu constraints here will break things later
bool constraints[2] = { pred_cu->violates_lfnst_constrained[0],
@ -440,73 +441,86 @@ static double search_intra_trdepth(
if (constraints[0] || !constraints[1]) {
continue;
}
}
double rd_cost = uvg_cu_rd_cost_luma(state, lcu_px.x, lcu_px.y, depth, pred_cu, lcu);
double mts_bits = 0;
if (num_transforms > 1 && trafo != MTS_SKIP && width <= 32 /*&& height <= 32*/
&& !pred_cu->violates_mts_coeff_constraint && pred_cu->mts_last_scan_pos) {
double rd_cost = uvg_cu_rd_cost_luma(state, lcu_px.x, lcu_px.y, depth, pred_cu, lcu);
//if (reconstruct_chroma) {
// rd_cost += uvg_cu_rd_cost_chroma(state, lcu_px.x, lcu_px.y, depth, pred_cu, lcu);
//}
bool symbol = trafo != 0;
int ctx_idx = 0;
mts_bits += CTX_ENTROPY_FBITS(&state->search_cabac.ctx.mts_idx_model[ctx_idx], symbol);
// TODO: there is an error in this cost calculation. This will be fixed when merged into current master
// This is compared to the previous best, which may have chroma cost included
if (rd_cost < best_rd_cost) {
best_rd_cost = rd_cost;
best_lfnst_idx = pred_cu->lfnst_idx;
best_tr_idx = pred_cu->tr_idx;
if (best_tr_idx == MTS_SKIP) break; // Very unlikely that further search is necessary if skip seems best option
}
} // end mts index loop (tr_idx)
if (reconstruct_chroma) {
int8_t luma_mode = pred_cu->intra.mode;
pred_cu->intra.mode = -1;
pred_cu->intra.mode_chroma = chroma_mode;
pred_cu->joint_cb_cr = 4; // TODO: Maybe check the jccr mode here also but holy shit is the interface of search_intra_rdo bad currently
uvg_intra_recon_cu(state,
x_px, y_px,
depth, search_data,
pred_cu,
lcu);
best_rd_cost += uvg_cu_rd_cost_chroma(state, lcu_px.x, lcu_px.y, depth, pred_cu, lcu);
pred_cu->intra.mode = luma_mode;
ctx_idx++;
for (int i = 0; i < 3 && symbol; i++, ctx_idx++)
{
symbol = trafo > i + MTS_DST7_DST7 ? 1 : 0;
mts_bits += CTX_ENTROPY_FBITS(&state->search_cabac.ctx.mts_idx_model[ctx_idx], symbol);
}
// Check lfnst constraints for chroma
if (pred_cu->lfnst_idx > 0) {
// Temp constraints. Updating the actual pred_cu constraints here will break things later
bool constraints[2] = { pred_cu->violates_lfnst_constrained[1],
pred_cu->lfnst_last_scan_pos };
derive_lfnst_constraints(pred_cu, lcu, depth, COLOR_U, lcu_px, constraints);
if (constraints[0] || !constraints[1]) {
best_lfnst_idx = 0;
continue;
}
derive_lfnst_constraints(pred_cu, lcu, depth, COLOR_V, lcu_px, constraints);
if (constraints[0] || !constraints[1]) {
best_lfnst_idx = 0;
continue;
rd_cost += mts_bits * state->frame->lambda;
// TODO: there is an error in this cost calculation. This will be fixed when merged into current master
// This is compared to the previous best, which may have chroma cost included
if (rd_cost < best_rd_cost) {
best_rd_cost = rd_cost;
best_lfnst_idx = pred_cu->lfnst_idx;
best_tr_idx = pred_cu->tr_idx;
if (best_tr_idx == MTS_SKIP) break; // Very unlikely that further search is necessary if skip seems best option
}
} // end mts index loop (tr_idx)
if (reconstruct_chroma) {
int8_t luma_mode = pred_cu->intra.mode;
pred_cu->intra.mode = -1;
pred_cu->intra.mode_chroma = chroma_mode;
pred_cu->joint_cb_cr = 4; // TODO: Maybe check the jccr mode here also but holy shit is the interface of search_intra_rdo bad currently
uvg_intra_recon_cu(state,
x_px, y_px,
depth, search_data,
pred_cu,
lcu);
best_rd_cost += uvg_cu_rd_cost_chroma(state, lcu_px.x, lcu_px.y, depth, pred_cu, lcu);
pred_cu->intra.mode = luma_mode;
// Check lfnst constraints for chroma
if (pred_cu->lfnst_idx > 0) {
// Temp constraints. Updating the actual pred_cu constraints here will break things later
bool constraints[2] = { pred_cu->violates_lfnst_constrained[1],
pred_cu->lfnst_last_scan_pos };
derive_lfnst_constraints(pred_cu, lcu, depth, COLOR_U, lcu_px, constraints);
if (constraints[0] || !constraints[1]) {
best_lfnst_idx = 0;
continue;
}
derive_lfnst_constraints(pred_cu, lcu, depth, COLOR_V, lcu_px, constraints);
if (constraints[0] || !constraints[1]) {
best_lfnst_idx = 0;
continue;
}
}
}
if (best_tr_idx == MTS_SKIP) break; // Very unlikely that further search is necessary if skip seems best option
} // end lfnst_index loop
pred_cu->tr_skip = best_tr_idx == MTS_SKIP;
pred_cu->tr_idx = best_tr_idx;
pred_cu->lfnst_idx = best_lfnst_idx;
nosplit_cost += best_rd_cost;
// Early stop condition for the recursive search.
// If the cost of any 1/4th of the transform is already larger than the
// whole transform, assume that splitting further is a bad idea.
if (nosplit_cost >= cost_treshold) {
return nosplit_cost;
}
if (best_tr_idx == MTS_SKIP) break; // Very unlikely that further search is necessary if skip seems best option
} // end lfnst_index loop
pred_cu->tr_skip = best_tr_idx == MTS_SKIP;
pred_cu->tr_idx = best_tr_idx;
pred_cu->lfnst_idx = best_lfnst_idx;
nosplit_cost += best_rd_cost;
// Early stop condition for the recursive search.
// If the cost of any 1/4th of the transform is already larger than the
// whole transform, assume that splitting further is a bad idea.
if (nosplit_cost >= cost_treshold) {
return nosplit_cost;
}
nosplit_cbf = pred_cu->cbf;
nosplit_cbf = pred_cu->cbf;
uvg_pixels_blit(lcu->rec.y, nosplit_pixels.y, width, width, LCU_WIDTH, width);
if (reconstruct_chroma) {
uvg_pixels_blit(lcu->rec.u, nosplit_pixels.u, width_c, width_c, LCU_WIDTH_C, width_c);
uvg_pixels_blit(lcu->rec.v, nosplit_pixels.v, width_c, width_c, LCU_WIDTH_C, width_c);
uvg_pixels_blit(lcu->rec.y, nosplit_pixels.y, width, width, LCU_WIDTH, width);
if (reconstruct_chroma) {
uvg_pixels_blit(lcu->rec.u, nosplit_pixels.u, width_c, width_c, LCU_WIDTH_C, width_c);
uvg_pixels_blit(lcu->rec.v, nosplit_pixels.v, width_c, width_c, LCU_WIDTH_C, width_c);
}
}
}
@ -910,9 +924,9 @@ static double count_bits(
static int16_t search_intra_rough(
encoder_state_t * const state,
const cu_loc_t* const cu_loc,
kvz_pixel *orig,
uvg_pixel *orig,
int32_t origstride,
kvz_intra_references *refs,
uvg_intra_references *refs,
int log2_width,
int8_t *intra_preds,
intra_search_data_t* modes_out,
@ -924,23 +938,23 @@ static int16_t search_intra_rough(
int_fast8_t width = 1 << log2_width;
// cost_pixel_nxn_func *satd_func = kvz_pixels_get_satd_func(width);
// cost_pixel_nxn_func *sad_func = kvz_pixels_get_sad_func(width);
cost_pixel_nxn_multi_func *satd_dual_func = kvz_pixels_get_satd_dual_func(width);
cost_pixel_nxn_multi_func *sad_dual_func = kvz_pixels_get_sad_dual_func(width);
bool mode_checked[KVZ_NUM_INTRA_MODES] = {0};
double costs[KVZ_NUM_INTRA_MODES];
cost_pixel_nxn_multi_func *satd_dual_func = uvg_pixels_get_satd_dual_func(width);
cost_pixel_nxn_multi_func *sad_dual_func = uvg_pixels_get_sad_dual_func(width);
bool mode_checked[UVG_NUM_INTRA_MODES] = {0};
double costs[UVG_NUM_INTRA_MODES];
// const kvz_config *cfg = &state->encoder_control->cfg;
// const bool filter_boundary = !(cfg->lossless && cfg->implicit_rdpcm);
// Temporary block arrays
kvz_pixel _preds[PARALLEL_BLKS * 32 * 32 + SIMD_ALIGNMENT];
uvg_pixel _preds[PARALLEL_BLKS * 32 * 32 + SIMD_ALIGNMENT];
pred_buffer preds = ALIGNED_POINTER(_preds, SIMD_ALIGNMENT);
kvz_pixel _orig_block[32 * 32 + SIMD_ALIGNMENT];
kvz_pixel *orig_block = ALIGNED_POINTER(_orig_block, SIMD_ALIGNMENT);
uvg_pixel _orig_block[32 * 32 + SIMD_ALIGNMENT];
uvg_pixel *orig_block = ALIGNED_POINTER(_orig_block, SIMD_ALIGNMENT);
// Store original block for SAD computation
kvz_pixels_blit(orig, orig_block, width, width, origstride, width);
uvg_pixels_blit(orig, orig_block, width, width, origstride, width);
int8_t modes_selected = 0;
// Note: get_cost and get_cost_dual may return negative costs.
@ -973,9 +987,9 @@ static int16_t search_intra_rough(
int offset = 4;
search_proxy.pred_cu.intra.mode = 0;
kvz_intra_predict(state, refs, &loc, COLOR_Y, preds[0], &search_proxy, NULL);
uvg_intra_predict(state, refs, &loc, COLOR_Y, preds[0], &search_proxy, NULL);
search_proxy.pred_cu.intra.mode = 1;
kvz_intra_predict(state, refs, &loc, COLOR_Y, preds[1], &search_proxy, NULL);
uvg_intra_predict(state, refs, &loc, COLOR_Y, preds[1], &search_proxy, NULL);
get_cost_dual(state, preds, orig_block, satd_dual_func, sad_dual_func, width, costs);
mode_checked[0] = true;
mode_checked[1] = true;
@ -1025,7 +1039,7 @@ static int16_t search_intra_rough(
for (int i = 0; i < PARALLEL_BLKS; ++i) {
if (mode + i * offset <= 66) {
search_proxy.pred_cu.intra.mode = mode + i*offset;
kvz_intra_predict(state, refs, &loc, COLOR_Y, preds[i], &search_proxy, NULL);
uvg_intra_predict(state, refs, &loc, COLOR_Y, preds[i], &search_proxy, NULL);
}
}
@ -1097,7 +1111,7 @@ static int16_t search_intra_rough(
for (int block = 0; block < PARALLEL_BLKS; ++block) {
search_proxy.pred_cu.intra.mode = modes_to_check[block + i];
kvz_intra_predict(state, refs, &loc, COLOR_Y, preds[block], &search_proxy, NULL);
uvg_intra_predict(state, refs, &loc, COLOR_Y, preds[block], &search_proxy, NULL);
}
@ -1765,10 +1779,7 @@ void uvg_search_cu_intra(
depth,
number_of_modes_to_search,
search_data,
lcu);
// Reset these
search_data[0].pred_cu.violates_mts_coeff_constraint = false;
search_data[0].pred_cu.mts_last_scan_pos = false;
lcu);
}
*mode_out = search_data[0];
}