[intra] count mts and tr_skip bits

This commit is contained in:
Joose Sainio 2022-05-13 15:04:03 +03:00
parent 804bf3afcb
commit 418c65cbf1
4 changed files with 120 additions and 104 deletions

View file

@ -47,9 +47,9 @@
#include "tables.h" #include "tables.h"
#include "videoframe.h" #include "videoframe.h"
static bool is_mts_allowed(encoder_state_t * const state, cu_info_t *const pred_cu) bool uvg_is_mts_allowed(encoder_state_t * const state, cu_info_t *const pred_cu)
{ {
uint32_t ts_max_size = 1 << 2; //cu.cs->sps->getLog2MaxTransformSkipBlockSize(); uint32_t ts_max_size = 1 << state->encoder_control->cfg.trskip_max_size;
const uint32_t max_size = 32; // CU::isIntra(cu) ? MTS_INTRA_MAX_CU_SIZE : MTS_INTER_MAX_CU_SIZE; const uint32_t max_size = 32; // CU::isIntra(cu) ? MTS_INTRA_MAX_CU_SIZE : MTS_INTER_MAX_CU_SIZE;
const uint32_t cu_width = LCU_WIDTH >> pred_cu->depth; const uint32_t cu_width = LCU_WIDTH >> pred_cu->depth;
const uint32_t cu_height = LCU_WIDTH >> pred_cu->depth; const uint32_t cu_height = LCU_WIDTH >> pred_cu->depth;
@ -61,6 +61,7 @@ static bool is_mts_allowed(encoder_state_t * const state, cu_info_t *const pred_
//mts_allowed &= !cu.ispMode; // ISP_TODO: Uncomment this when ISP is implemented. //mts_allowed &= !cu.ispMode; // ISP_TODO: Uncomment this when ISP is implemented.
//mts_allowed &= !cu.sbtInfo; //mts_allowed &= !cu.sbtInfo;
mts_allowed &= !(pred_cu->bdpcmMode && cu_width <= ts_max_size && cu_height <= ts_max_size); mts_allowed &= !(pred_cu->bdpcmMode && cu_width <= ts_max_size && cu_height <= ts_max_size);
mts_allowed &= pred_cu->tr_idx != MTS_SKIP && !pred_cu->violates_mts_coeff_constraint && pred_cu->mts_last_scan_pos;
return mts_allowed; return mts_allowed;
} }
@ -71,7 +72,7 @@ static void encode_mts_idx(encoder_state_t * const state,
//TransformUnit &tu = *cu.firstTU; //TransformUnit &tu = *cu.firstTU;
int mts_idx = pred_cu->tr_idx; int mts_idx = pred_cu->tr_idx;
if (is_mts_allowed(state, (cu_info_t* const )pred_cu) && mts_idx != MTS_SKIP if (uvg_is_mts_allowed(state, (cu_info_t* const )pred_cu) && mts_idx != MTS_SKIP
&& !pred_cu->violates_mts_coeff_constraint && !pred_cu->violates_mts_coeff_constraint
&& pred_cu->mts_last_scan_pos && pred_cu->mts_last_scan_pos
&& pred_cu->lfnst_idx == 0 && pred_cu->lfnst_idx == 0
@ -718,7 +719,7 @@ static void encode_transform_coeff(encoder_state_t * const state,
if ((cur_cu->type == CU_INTRA || tr_depth > 0 || cb_flag_u || cb_flag_v) && !only_chroma) { if ((cur_cu->type == CU_INTRA || tr_depth > 0 || cb_flag_u || cb_flag_v) && !only_chroma) {
cabac->cur_ctx = &(cabac->ctx.qt_cbf_model_luma[0]); cabac->cur_ctx = &(cabac->ctx.qt_cbf_model_luma[0]);
CABAC_BIN(cabac, cb_flag_y, "cbf_luma"); CABAC_BIN(cabac, cb_flag_y, "cbf_luma");
// printf("%hu %hu %d %d\n", cabac->ctx.qt_cbf_model_luma[0].state[0], cabac->ctx.qt_cbf_model_luma[0].state[1], x, y); printf("%hu %hu %d %d\n", cabac->ctx.qt_cbf_model_luma[0].state[0], cabac->ctx.qt_cbf_model_luma[0].state[1], x, y);
} }
if (cb_flag_y | cb_flag_u | cb_flag_v) { if (cb_flag_y | cb_flag_u | cb_flag_v) {

View file

@ -40,7 +40,9 @@
#include "encoderstate.h" #include "encoderstate.h"
#include "global.h" #include "global.h"
void uvg_encode_coding_tree(encoder_state_t * const state, bool uvg_is_mts_allowed(encoder_state_t* const state, cu_info_t* const pred_cu);
void kvz_encode_coding_tree(encoder_state_t * const state,
uint16_t x_ctb, uint16_t x_ctb,
uint16_t y_ctb, uint16_t y_ctb,
uint8_t depth, uint8_t depth,

View file

@ -324,38 +324,23 @@ double uvg_cu_rd_cost_luma(const encoder_state_t *const state,
return sum + tr_tree_bits * state->lambda; return sum + tr_tree_bits * state->lambda;
} }
if (cabac->update && tr_cu->tr_depth == tr_cu->depth && !skip_residual_coding) {
// Because these need to be coded before the luma cbf they also need to be counted
// before the cabac state changes. However, since this branch is only executed when
// calculating the last RD cost it is not problem to include the chroma cbf costs in
// luma, because the chroma cost is calculated right after the luma cost.
// However, if we have different tr_depth, the bits cannot be written in correct
// order anyways so do not touch the chroma cbf here.
if (state->encoder_control->chroma_format != UVG_CSP_400) {
cabac_ctx_t* cr_ctx = &(cabac->ctx.qt_cbf_model_cb[0]);
cabac->cur_ctx = cr_ctx;
int u_is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_U);
int v_is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_V);
CABAC_FBITS_UPDATE(cabac, cr_ctx, u_is_set, tr_tree_bits, "cbf_cb_search");
cr_ctx = &(cabac->ctx.qt_cbf_model_cr[u_is_set]);
CABAC_FBITS_UPDATE(cabac, cr_ctx, v_is_set, tr_tree_bits, "cbf_cb_search");
}
}
// Add transform_tree cbf_luma bit cost. // Add transform_tree cbf_luma bit cost.
const int is_tr_split = tr_cu->tr_depth - tr_cu->depth; const int is_tr_split = tr_cu->tr_depth - tr_cu->depth;
int is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_Y);
if (pred_cu->type == CU_INTRA || if (pred_cu->type == CU_INTRA ||
is_tr_split || is_tr_split ||
cbf_is_set(tr_cu->cbf, depth, COLOR_U) || cbf_is_set(tr_cu->cbf, depth, COLOR_U) ||
cbf_is_set(tr_cu->cbf, depth, COLOR_V)) cbf_is_set(tr_cu->cbf, depth, COLOR_V))
{ {
cabac_ctx_t *ctx = &(cabac->ctx.qt_cbf_model_luma[0]); cabac_ctx_t *ctx = &(cabac->ctx.qt_cbf_model_luma[0]);
int is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_Y);
CABAC_FBITS_UPDATE(cabac, ctx, is_set, tr_tree_bits, "cbf_y_search"); CABAC_FBITS_UPDATE(cabac, ctx, is_set, tr_tree_bits, "cbf_y_search");
} }
if (is_set && state->encoder_control->cfg.trskip_enable && width <= (1 << state->encoder_control->cfg.trskip_max_size)) {
CABAC_FBITS_UPDATE(cabac, &cabac->ctx.transform_skip_model_luma, tr_cu->tr_idx == MTS_SKIP, tr_tree_bits, "transform_skip_flag");
}
// SSD between reconstruction and original // SSD between reconstruction and original
int ssd = 0; int ssd = 0;
if (!state->encoder_control->cfg.lossless) { if (!state->encoder_control->cfg.lossless) {
@ -555,7 +540,10 @@ static double cu_rd_cost_tr_split_accurate(const encoder_state_t* const state,
width); width);
} }
{ if(cb_flag_y){
if (state->encoder_control->cfg.trskip_enable && width <= (1 << state->encoder_control->cfg.trskip_max_size)) {
CABAC_FBITS_UPDATE(cabac, &cabac->ctx.transform_skip_model_luma, tr_cu->tr_idx == MTS_SKIP, tr_tree_bits, "transform_skip_flag");
}
int8_t luma_scan_mode = uvg_get_scan_order(pred_cu->type, pred_cu->intra.mode, depth); int8_t luma_scan_mode = uvg_get_scan_order(pred_cu->type, pred_cu->intra.mode, depth);
const coeff_t* coeffs = &lcu->coeff.y[xy_to_zorder(LCU_WIDTH, x_px, y_px)]; const coeff_t* coeffs = &lcu->coeff.y[xy_to_zorder(LCU_WIDTH, x_px, y_px)];
@ -596,6 +584,20 @@ static double cu_rd_cost_tr_split_accurate(const encoder_state_t* const state,
coeff_bits += uvg_get_coeff_cost(state, &lcu->coeff.joint_uv[index], width, 2, scan_order, 0); coeff_bits += uvg_get_coeff_cost(state, &lcu->coeff.joint_uv[index], width, 2, scan_order, 0);
} }
} }
if (kvz_is_mts_allowed(state, tr_cu)) {
bool symbol = tr_cu->tr_idx != 0;
int ctx_idx = 0;
CABAC_FBITS_UPDATE(cabac, &state->search_cabac.ctx.mts_idx_model[ctx_idx], symbol, tr_tree_bits, "mts_idx");
ctx_idx++;
for (int i = 0; i < 3 && symbol; i++, ctx_idx++)
{
symbol = tr_cu->tr_idx > i + MTS_DST7_DST7 ? 1 : 0;
CABAC_FBITS_UPDATE(cabac, &state->search_cabac.ctx.mts_idx_model[ctx_idx], symbol, tr_tree_bits, "mts_idx");
}
}
double bits = tr_tree_bits + coeff_bits; double bits = tr_tree_bits + coeff_bits;
return luma_ssd * UVG_LUMA_MULT + chroma_ssd * UVG_CHROMA_MULT + bits * state->lambda; return luma_ssd * UVG_LUMA_MULT + chroma_ssd * UVG_CHROMA_MULT + bits * state->lambda;

View file

@ -381,7 +381,7 @@ static double search_intra_trdepth(
} }
const int mts_start = trafo; const int mts_start = trafo;
//TODO: height //TODO: height
if(state->encoder_control->cfg.trskip_enable && width <= (1 << state->encoder_control->cfg.trskip_max_size) /*&& height == 4*/) { if (state->encoder_control->cfg.trskip_enable && width <= (1 << state->encoder_control->cfg.trskip_max_size) /*&& height == 4*/) {
num_transforms = MAX(num_transforms, 2); num_transforms = MAX(num_transforms, 2);
} }
pred_cu->intra.mode_chroma = -1; pred_cu->intra.mode_chroma = -1;
@ -424,14 +424,15 @@ static double search_intra_trdepth(
lcu); lcu);
// TODO: Not sure if this should be 0 or 1 but at least seems to work with 1 // TODO: Not sure if this should be 0 or 1 but at least seems to work with 1
derive_mts_constraints(pred_cu, lcu, depth, lcu_px);
if (pred_cu->tr_idx > 1) if (pred_cu->tr_idx > 1)
{ {
derive_mts_constraints(pred_cu, lcu, depth, lcu_px);
if (pred_cu->violates_mts_coeff_constraint || !pred_cu->mts_last_scan_pos) if (pred_cu->violates_mts_coeff_constraint || !pred_cu->mts_last_scan_pos)
{ {
continue; continue;
} }
} }
if (pred_cu->lfnst_idx > 0) { if (pred_cu->lfnst_idx > 0) {
// Temp constraints. Updating the actual pred_cu constraints here will break things later // Temp constraints. Updating the actual pred_cu constraints here will break things later
bool constraints[2] = { pred_cu->violates_lfnst_constrained[0], bool constraints[2] = { pred_cu->violates_lfnst_constrained[0],
@ -440,12 +441,24 @@ static double search_intra_trdepth(
if (constraints[0] || !constraints[1]) { if (constraints[0] || !constraints[1]) {
continue; continue;
} }
double rd_cost = uvg_cu_rd_cost_luma(state, lcu_px.x, lcu_px.y, depth, pred_cu, lcu);
double mts_bits = 0;
if (num_transforms > 1 && trafo != MTS_SKIP && width <= 32 /*&& height <= 32*/
&& !pred_cu->violates_mts_coeff_constraint && pred_cu->mts_last_scan_pos) {
bool symbol = trafo != 0;
int ctx_idx = 0;
mts_bits += CTX_ENTROPY_FBITS(&state->search_cabac.ctx.mts_idx_model[ctx_idx], symbol);
ctx_idx++;
for (int i = 0; i < 3 && symbol; i++, ctx_idx++)
{
symbol = trafo > i + MTS_DST7_DST7 ? 1 : 0;
mts_bits += CTX_ENTROPY_FBITS(&state->search_cabac.ctx.mts_idx_model[ctx_idx], symbol);
} }
double rd_cost = uvg_cu_rd_cost_luma(state, lcu_px.x, lcu_px.y, depth, pred_cu, lcu); }
//if (reconstruct_chroma) { rd_cost += mts_bits * state->frame->lambda;
// rd_cost += uvg_cu_rd_cost_chroma(state, lcu_px.x, lcu_px.y, depth, pred_cu, lcu);
//}
// TODO: there is an error in this cost calculation. This will be fixed when merged into current master // TODO: there is an error in this cost calculation. This will be fixed when merged into current master
// This is compared to the previous best, which may have chroma cost included // This is compared to the previous best, which may have chroma cost included
@ -509,6 +522,7 @@ static double search_intra_trdepth(
uvg_pixels_blit(lcu->rec.v, nosplit_pixels.v, width_c, width_c, LCU_WIDTH_C, width_c); uvg_pixels_blit(lcu->rec.v, nosplit_pixels.v, width_c, width_c, LCU_WIDTH_C, width_c);
} }
} }
}
// Recurse further if all of the following: // Recurse further if all of the following:
// - Current depth is less than maximum depth of the search (max_depth). // - Current depth is less than maximum depth of the search (max_depth).
@ -910,9 +924,9 @@ static double count_bits(
static int16_t search_intra_rough( static int16_t search_intra_rough(
encoder_state_t * const state, encoder_state_t * const state,
const cu_loc_t* const cu_loc, const cu_loc_t* const cu_loc,
kvz_pixel *orig, uvg_pixel *orig,
int32_t origstride, int32_t origstride,
kvz_intra_references *refs, uvg_intra_references *refs,
int log2_width, int log2_width,
int8_t *intra_preds, int8_t *intra_preds,
intra_search_data_t* modes_out, intra_search_data_t* modes_out,
@ -924,23 +938,23 @@ static int16_t search_intra_rough(
int_fast8_t width = 1 << log2_width; int_fast8_t width = 1 << log2_width;
// cost_pixel_nxn_func *satd_func = kvz_pixels_get_satd_func(width); // cost_pixel_nxn_func *satd_func = kvz_pixels_get_satd_func(width);
// cost_pixel_nxn_func *sad_func = kvz_pixels_get_sad_func(width); // cost_pixel_nxn_func *sad_func = kvz_pixels_get_sad_func(width);
cost_pixel_nxn_multi_func *satd_dual_func = kvz_pixels_get_satd_dual_func(width); cost_pixel_nxn_multi_func *satd_dual_func = uvg_pixels_get_satd_dual_func(width);
cost_pixel_nxn_multi_func *sad_dual_func = kvz_pixels_get_sad_dual_func(width); cost_pixel_nxn_multi_func *sad_dual_func = uvg_pixels_get_sad_dual_func(width);
bool mode_checked[KVZ_NUM_INTRA_MODES] = {0}; bool mode_checked[UVG_NUM_INTRA_MODES] = {0};
double costs[KVZ_NUM_INTRA_MODES]; double costs[UVG_NUM_INTRA_MODES];
// const kvz_config *cfg = &state->encoder_control->cfg; // const kvz_config *cfg = &state->encoder_control->cfg;
// const bool filter_boundary = !(cfg->lossless && cfg->implicit_rdpcm); // const bool filter_boundary = !(cfg->lossless && cfg->implicit_rdpcm);
// Temporary block arrays // Temporary block arrays
kvz_pixel _preds[PARALLEL_BLKS * 32 * 32 + SIMD_ALIGNMENT]; uvg_pixel _preds[PARALLEL_BLKS * 32 * 32 + SIMD_ALIGNMENT];
pred_buffer preds = ALIGNED_POINTER(_preds, SIMD_ALIGNMENT); pred_buffer preds = ALIGNED_POINTER(_preds, SIMD_ALIGNMENT);
kvz_pixel _orig_block[32 * 32 + SIMD_ALIGNMENT]; uvg_pixel _orig_block[32 * 32 + SIMD_ALIGNMENT];
kvz_pixel *orig_block = ALIGNED_POINTER(_orig_block, SIMD_ALIGNMENT); uvg_pixel *orig_block = ALIGNED_POINTER(_orig_block, SIMD_ALIGNMENT);
// Store original block for SAD computation // Store original block for SAD computation
kvz_pixels_blit(orig, orig_block, width, width, origstride, width); uvg_pixels_blit(orig, orig_block, width, width, origstride, width);
int8_t modes_selected = 0; int8_t modes_selected = 0;
// Note: get_cost and get_cost_dual may return negative costs. // Note: get_cost and get_cost_dual may return negative costs.
@ -973,9 +987,9 @@ static int16_t search_intra_rough(
int offset = 4; int offset = 4;
search_proxy.pred_cu.intra.mode = 0; search_proxy.pred_cu.intra.mode = 0;
kvz_intra_predict(state, refs, &loc, COLOR_Y, preds[0], &search_proxy, NULL); uvg_intra_predict(state, refs, &loc, COLOR_Y, preds[0], &search_proxy, NULL);
search_proxy.pred_cu.intra.mode = 1; search_proxy.pred_cu.intra.mode = 1;
kvz_intra_predict(state, refs, &loc, COLOR_Y, preds[1], &search_proxy, NULL); uvg_intra_predict(state, refs, &loc, COLOR_Y, preds[1], &search_proxy, NULL);
get_cost_dual(state, preds, orig_block, satd_dual_func, sad_dual_func, width, costs); get_cost_dual(state, preds, orig_block, satd_dual_func, sad_dual_func, width, costs);
mode_checked[0] = true; mode_checked[0] = true;
mode_checked[1] = true; mode_checked[1] = true;
@ -1025,7 +1039,7 @@ static int16_t search_intra_rough(
for (int i = 0; i < PARALLEL_BLKS; ++i) { for (int i = 0; i < PARALLEL_BLKS; ++i) {
if (mode + i * offset <= 66) { if (mode + i * offset <= 66) {
search_proxy.pred_cu.intra.mode = mode + i*offset; search_proxy.pred_cu.intra.mode = mode + i*offset;
kvz_intra_predict(state, refs, &loc, COLOR_Y, preds[i], &search_proxy, NULL); uvg_intra_predict(state, refs, &loc, COLOR_Y, preds[i], &search_proxy, NULL);
} }
} }
@ -1097,7 +1111,7 @@ static int16_t search_intra_rough(
for (int block = 0; block < PARALLEL_BLKS; ++block) { for (int block = 0; block < PARALLEL_BLKS; ++block) {
search_proxy.pred_cu.intra.mode = modes_to_check[block + i]; search_proxy.pred_cu.intra.mode = modes_to_check[block + i];
kvz_intra_predict(state, refs, &loc, COLOR_Y, preds[block], &search_proxy, NULL); uvg_intra_predict(state, refs, &loc, COLOR_Y, preds[block], &search_proxy, NULL);
} }
@ -1766,9 +1780,6 @@ void uvg_search_cu_intra(
number_of_modes_to_search, number_of_modes_to_search,
search_data, search_data,
lcu); lcu);
// Reset these
search_data[0].pred_cu.violates_mts_coeff_constraint = false;
search_data[0].pred_cu.mts_last_scan_pos = false;
} }
*mode_out = search_data[0]; *mode_out = search_data[0];
} }