From 418c65cbf1f040a9a5ce8e35f0f3baa76bf13c8e Mon Sep 17 00:00:00 2001 From: Joose Sainio Date: Fri, 13 May 2022 15:04:03 +0300 Subject: [PATCH] [intra] count mts and tr_skip bits --- src/encode_coding_tree.c | 9 ++- src/encode_coding_tree.h | 4 +- src/search.c | 44 ++++++----- src/search_intra.c | 167 +++++++++++++++++++++------------------ 4 files changed, 120 insertions(+), 104 deletions(-) diff --git a/src/encode_coding_tree.c b/src/encode_coding_tree.c index e05a1c28..aa939380 100644 --- a/src/encode_coding_tree.c +++ b/src/encode_coding_tree.c @@ -47,9 +47,9 @@ #include "tables.h" #include "videoframe.h" -static bool is_mts_allowed(encoder_state_t * const state, cu_info_t *const pred_cu) +bool uvg_is_mts_allowed(encoder_state_t * const state, cu_info_t *const pred_cu) { - uint32_t ts_max_size = 1 << 2; //cu.cs->sps->getLog2MaxTransformSkipBlockSize(); + uint32_t ts_max_size = 1 << state->encoder_control->cfg.trskip_max_size; const uint32_t max_size = 32; // CU::isIntra(cu) ? MTS_INTRA_MAX_CU_SIZE : MTS_INTER_MAX_CU_SIZE; const uint32_t cu_width = LCU_WIDTH >> pred_cu->depth; const uint32_t cu_height = LCU_WIDTH >> pred_cu->depth; @@ -61,6 +61,7 @@ static bool is_mts_allowed(encoder_state_t * const state, cu_info_t *const pred_ //mts_allowed &= !cu.ispMode; // ISP_TODO: Uncomment this when ISP is implemented. //mts_allowed &= !cu.sbtInfo; mts_allowed &= !(pred_cu->bdpcmMode && cu_width <= ts_max_size && cu_height <= ts_max_size); + mts_allowed &= pred_cu->tr_idx != MTS_SKIP && !pred_cu->violates_mts_coeff_constraint && pred_cu->mts_last_scan_pos; return mts_allowed; } @@ -71,7 +72,7 @@ static void encode_mts_idx(encoder_state_t * const state, //TransformUnit &tu = *cu.firstTU; int mts_idx = pred_cu->tr_idx; - if (is_mts_allowed(state, (cu_info_t* const )pred_cu) && mts_idx != MTS_SKIP + if (uvg_is_mts_allowed(state, (cu_info_t* const )pred_cu) && mts_idx != MTS_SKIP && !pred_cu->violates_mts_coeff_constraint && pred_cu->mts_last_scan_pos && pred_cu->lfnst_idx == 0 @@ -718,7 +719,7 @@ static void encode_transform_coeff(encoder_state_t * const state, if ((cur_cu->type == CU_INTRA || tr_depth > 0 || cb_flag_u || cb_flag_v) && !only_chroma) { cabac->cur_ctx = &(cabac->ctx.qt_cbf_model_luma[0]); CABAC_BIN(cabac, cb_flag_y, "cbf_luma"); - // printf("%hu %hu %d %d\n", cabac->ctx.qt_cbf_model_luma[0].state[0], cabac->ctx.qt_cbf_model_luma[0].state[1], x, y); + printf("%hu %hu %d %d\n", cabac->ctx.qt_cbf_model_luma[0].state[0], cabac->ctx.qt_cbf_model_luma[0].state[1], x, y); } if (cb_flag_y | cb_flag_u | cb_flag_v) { diff --git a/src/encode_coding_tree.h b/src/encode_coding_tree.h index 11989a63..23c927c9 100644 --- a/src/encode_coding_tree.h +++ b/src/encode_coding_tree.h @@ -40,7 +40,9 @@ #include "encoderstate.h" #include "global.h" -void uvg_encode_coding_tree(encoder_state_t * const state, +bool uvg_is_mts_allowed(encoder_state_t* const state, cu_info_t* const pred_cu); + +void kvz_encode_coding_tree(encoder_state_t * const state, uint16_t x_ctb, uint16_t y_ctb, uint8_t depth, diff --git a/src/search.c b/src/search.c index 945b3219..1deebc94 100644 --- a/src/search.c +++ b/src/search.c @@ -324,38 +324,23 @@ double uvg_cu_rd_cost_luma(const encoder_state_t *const state, return sum + tr_tree_bits * state->lambda; } - - if (cabac->update && tr_cu->tr_depth == tr_cu->depth && !skip_residual_coding) { - // Because these need to be coded before the luma cbf they also need to be counted - // before the cabac state changes. However, since this branch is only executed when - // calculating the last RD cost it is not problem to include the chroma cbf costs in - // luma, because the chroma cost is calculated right after the luma cost. - // However, if we have different tr_depth, the bits cannot be written in correct - // order anyways so do not touch the chroma cbf here. - if (state->encoder_control->chroma_format != UVG_CSP_400) { - cabac_ctx_t* cr_ctx = &(cabac->ctx.qt_cbf_model_cb[0]); - cabac->cur_ctx = cr_ctx; - int u_is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_U); - int v_is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_V); - CABAC_FBITS_UPDATE(cabac, cr_ctx, u_is_set, tr_tree_bits, "cbf_cb_search"); - cr_ctx = &(cabac->ctx.qt_cbf_model_cr[u_is_set]); - CABAC_FBITS_UPDATE(cabac, cr_ctx, v_is_set, tr_tree_bits, "cbf_cb_search"); - } - } - // Add transform_tree cbf_luma bit cost. const int is_tr_split = tr_cu->tr_depth - tr_cu->depth; + int is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_Y); if (pred_cu->type == CU_INTRA || is_tr_split || cbf_is_set(tr_cu->cbf, depth, COLOR_U) || cbf_is_set(tr_cu->cbf, depth, COLOR_V)) { cabac_ctx_t *ctx = &(cabac->ctx.qt_cbf_model_luma[0]); - int is_set = cbf_is_set(pred_cu->cbf, depth, COLOR_Y); CABAC_FBITS_UPDATE(cabac, ctx, is_set, tr_tree_bits, "cbf_y_search"); } + if (is_set && state->encoder_control->cfg.trskip_enable && width <= (1 << state->encoder_control->cfg.trskip_max_size)) { + CABAC_FBITS_UPDATE(cabac, &cabac->ctx.transform_skip_model_luma, tr_cu->tr_idx == MTS_SKIP, tr_tree_bits, "transform_skip_flag"); + } + // SSD between reconstruction and original int ssd = 0; if (!state->encoder_control->cfg.lossless) { @@ -555,7 +540,10 @@ static double cu_rd_cost_tr_split_accurate(const encoder_state_t* const state, width); } - { + if(cb_flag_y){ + if (state->encoder_control->cfg.trskip_enable && width <= (1 << state->encoder_control->cfg.trskip_max_size)) { + CABAC_FBITS_UPDATE(cabac, &cabac->ctx.transform_skip_model_luma, tr_cu->tr_idx == MTS_SKIP, tr_tree_bits, "transform_skip_flag"); + } int8_t luma_scan_mode = uvg_get_scan_order(pred_cu->type, pred_cu->intra.mode, depth); const coeff_t* coeffs = &lcu->coeff.y[xy_to_zorder(LCU_WIDTH, x_px, y_px)]; @@ -596,6 +584,20 @@ static double cu_rd_cost_tr_split_accurate(const encoder_state_t* const state, coeff_bits += uvg_get_coeff_cost(state, &lcu->coeff.joint_uv[index], width, 2, scan_order, 0); } } + if (kvz_is_mts_allowed(state, tr_cu)) { + + bool symbol = tr_cu->tr_idx != 0; + int ctx_idx = 0; + CABAC_FBITS_UPDATE(cabac, &state->search_cabac.ctx.mts_idx_model[ctx_idx], symbol, tr_tree_bits, "mts_idx"); + + ctx_idx++; + for (int i = 0; i < 3 && symbol; i++, ctx_idx++) + { + symbol = tr_cu->tr_idx > i + MTS_DST7_DST7 ? 1 : 0; + CABAC_FBITS_UPDATE(cabac, &state->search_cabac.ctx.mts_idx_model[ctx_idx], symbol, tr_tree_bits, "mts_idx"); + } + + } double bits = tr_tree_bits + coeff_bits; return luma_ssd * UVG_LUMA_MULT + chroma_ssd * UVG_CHROMA_MULT + bits * state->lambda; diff --git a/src/search_intra.c b/src/search_intra.c index 23c36a5e..25f806aa 100644 --- a/src/search_intra.c +++ b/src/search_intra.c @@ -381,7 +381,7 @@ static double search_intra_trdepth( } const int mts_start = trafo; //TODO: height - if(state->encoder_control->cfg.trskip_enable && width <= (1 << state->encoder_control->cfg.trskip_max_size) /*&& height == 4*/) { + if (state->encoder_control->cfg.trskip_enable && width <= (1 << state->encoder_control->cfg.trskip_max_size) /*&& height == 4*/) { num_transforms = MAX(num_transforms, 2); } pred_cu->intra.mode_chroma = -1; @@ -424,14 +424,15 @@ static double search_intra_trdepth( lcu); // TODO: Not sure if this should be 0 or 1 but at least seems to work with 1 + derive_mts_constraints(pred_cu, lcu, depth, lcu_px); if (pred_cu->tr_idx > 1) { - derive_mts_constraints(pred_cu, lcu, depth, lcu_px); if (pred_cu->violates_mts_coeff_constraint || !pred_cu->mts_last_scan_pos) { continue; } } + if (pred_cu->lfnst_idx > 0) { // Temp constraints. Updating the actual pred_cu constraints here will break things later bool constraints[2] = { pred_cu->violates_lfnst_constrained[0], @@ -440,73 +441,86 @@ static double search_intra_trdepth( if (constraints[0] || !constraints[1]) { continue; } - } + double rd_cost = uvg_cu_rd_cost_luma(state, lcu_px.x, lcu_px.y, depth, pred_cu, lcu); + double mts_bits = 0; + if (num_transforms > 1 && trafo != MTS_SKIP && width <= 32 /*&& height <= 32*/ + && !pred_cu->violates_mts_coeff_constraint && pred_cu->mts_last_scan_pos) { - double rd_cost = uvg_cu_rd_cost_luma(state, lcu_px.x, lcu_px.y, depth, pred_cu, lcu); - //if (reconstruct_chroma) { - // rd_cost += uvg_cu_rd_cost_chroma(state, lcu_px.x, lcu_px.y, depth, pred_cu, lcu); - //} + bool symbol = trafo != 0; + int ctx_idx = 0; + mts_bits += CTX_ENTROPY_FBITS(&state->search_cabac.ctx.mts_idx_model[ctx_idx], symbol); - // TODO: there is an error in this cost calculation. This will be fixed when merged into current master - // This is compared to the previous best, which may have chroma cost included - if (rd_cost < best_rd_cost) { - best_rd_cost = rd_cost; - best_lfnst_idx = pred_cu->lfnst_idx; - best_tr_idx = pred_cu->tr_idx; - if (best_tr_idx == MTS_SKIP) break; // Very unlikely that further search is necessary if skip seems best option - } - } // end mts index loop (tr_idx) - if (reconstruct_chroma) { - int8_t luma_mode = pred_cu->intra.mode; - pred_cu->intra.mode = -1; - pred_cu->intra.mode_chroma = chroma_mode; - pred_cu->joint_cb_cr = 4; // TODO: Maybe check the jccr mode here also but holy shit is the interface of search_intra_rdo bad currently - uvg_intra_recon_cu(state, - x_px, y_px, - depth, search_data, - pred_cu, - lcu); - best_rd_cost += uvg_cu_rd_cost_chroma(state, lcu_px.x, lcu_px.y, depth, pred_cu, lcu); - pred_cu->intra.mode = luma_mode; + ctx_idx++; + for (int i = 0; i < 3 && symbol; i++, ctx_idx++) + { + symbol = trafo > i + MTS_DST7_DST7 ? 1 : 0; + mts_bits += CTX_ENTROPY_FBITS(&state->search_cabac.ctx.mts_idx_model[ctx_idx], symbol); + } - // Check lfnst constraints for chroma - if (pred_cu->lfnst_idx > 0) { - // Temp constraints. Updating the actual pred_cu constraints here will break things later - bool constraints[2] = { pred_cu->violates_lfnst_constrained[1], - pred_cu->lfnst_last_scan_pos }; - derive_lfnst_constraints(pred_cu, lcu, depth, COLOR_U, lcu_px, constraints); - if (constraints[0] || !constraints[1]) { - best_lfnst_idx = 0; - continue; } - derive_lfnst_constraints(pred_cu, lcu, depth, COLOR_V, lcu_px, constraints); - if (constraints[0] || !constraints[1]) { - best_lfnst_idx = 0; - continue; + rd_cost += mts_bits * state->frame->lambda; + + // TODO: there is an error in this cost calculation. This will be fixed when merged into current master + // This is compared to the previous best, which may have chroma cost included + if (rd_cost < best_rd_cost) { + best_rd_cost = rd_cost; + best_lfnst_idx = pred_cu->lfnst_idx; + best_tr_idx = pred_cu->tr_idx; + if (best_tr_idx == MTS_SKIP) break; // Very unlikely that further search is necessary if skip seems best option + } + } // end mts index loop (tr_idx) + if (reconstruct_chroma) { + int8_t luma_mode = pred_cu->intra.mode; + pred_cu->intra.mode = -1; + pred_cu->intra.mode_chroma = chroma_mode; + pred_cu->joint_cb_cr = 4; // TODO: Maybe check the jccr mode here also but holy shit is the interface of search_intra_rdo bad currently + uvg_intra_recon_cu(state, + x_px, y_px, + depth, search_data, + pred_cu, + lcu); + best_rd_cost += uvg_cu_rd_cost_chroma(state, lcu_px.x, lcu_px.y, depth, pred_cu, lcu); + pred_cu->intra.mode = luma_mode; + + // Check lfnst constraints for chroma + if (pred_cu->lfnst_idx > 0) { + // Temp constraints. Updating the actual pred_cu constraints here will break things later + bool constraints[2] = { pred_cu->violates_lfnst_constrained[1], + pred_cu->lfnst_last_scan_pos }; + derive_lfnst_constraints(pred_cu, lcu, depth, COLOR_U, lcu_px, constraints); + if (constraints[0] || !constraints[1]) { + best_lfnst_idx = 0; + continue; + } + derive_lfnst_constraints(pred_cu, lcu, depth, COLOR_V, lcu_px, constraints); + if (constraints[0] || !constraints[1]) { + best_lfnst_idx = 0; + continue; + } } } + if (best_tr_idx == MTS_SKIP) break; // Very unlikely that further search is necessary if skip seems best option + } // end lfnst_index loop + + pred_cu->tr_skip = best_tr_idx == MTS_SKIP; + pred_cu->tr_idx = best_tr_idx; + pred_cu->lfnst_idx = best_lfnst_idx; + nosplit_cost += best_rd_cost; + + // Early stop condition for the recursive search. + // If the cost of any 1/4th of the transform is already larger than the + // whole transform, assume that splitting further is a bad idea. + if (nosplit_cost >= cost_treshold) { + return nosplit_cost; } - if (best_tr_idx == MTS_SKIP) break; // Very unlikely that further search is necessary if skip seems best option - } // end lfnst_index loop - - pred_cu->tr_skip = best_tr_idx == MTS_SKIP; - pred_cu->tr_idx = best_tr_idx; - pred_cu->lfnst_idx = best_lfnst_idx; - nosplit_cost += best_rd_cost; - - // Early stop condition for the recursive search. - // If the cost of any 1/4th of the transform is already larger than the - // whole transform, assume that splitting further is a bad idea. - if (nosplit_cost >= cost_treshold) { - return nosplit_cost; - } - nosplit_cbf = pred_cu->cbf; + nosplit_cbf = pred_cu->cbf; - uvg_pixels_blit(lcu->rec.y, nosplit_pixels.y, width, width, LCU_WIDTH, width); - if (reconstruct_chroma) { - uvg_pixels_blit(lcu->rec.u, nosplit_pixels.u, width_c, width_c, LCU_WIDTH_C, width_c); - uvg_pixels_blit(lcu->rec.v, nosplit_pixels.v, width_c, width_c, LCU_WIDTH_C, width_c); + uvg_pixels_blit(lcu->rec.y, nosplit_pixels.y, width, width, LCU_WIDTH, width); + if (reconstruct_chroma) { + uvg_pixels_blit(lcu->rec.u, nosplit_pixels.u, width_c, width_c, LCU_WIDTH_C, width_c); + uvg_pixels_blit(lcu->rec.v, nosplit_pixels.v, width_c, width_c, LCU_WIDTH_C, width_c); + } } } @@ -910,9 +924,9 @@ static double count_bits( static int16_t search_intra_rough( encoder_state_t * const state, const cu_loc_t* const cu_loc, - kvz_pixel *orig, + uvg_pixel *orig, int32_t origstride, - kvz_intra_references *refs, + uvg_intra_references *refs, int log2_width, int8_t *intra_preds, intra_search_data_t* modes_out, @@ -924,23 +938,23 @@ static int16_t search_intra_rough( int_fast8_t width = 1 << log2_width; // cost_pixel_nxn_func *satd_func = kvz_pixels_get_satd_func(width); // cost_pixel_nxn_func *sad_func = kvz_pixels_get_sad_func(width); - cost_pixel_nxn_multi_func *satd_dual_func = kvz_pixels_get_satd_dual_func(width); - cost_pixel_nxn_multi_func *sad_dual_func = kvz_pixels_get_sad_dual_func(width); - bool mode_checked[KVZ_NUM_INTRA_MODES] = {0}; - double costs[KVZ_NUM_INTRA_MODES]; + cost_pixel_nxn_multi_func *satd_dual_func = uvg_pixels_get_satd_dual_func(width); + cost_pixel_nxn_multi_func *sad_dual_func = uvg_pixels_get_sad_dual_func(width); + bool mode_checked[UVG_NUM_INTRA_MODES] = {0}; + double costs[UVG_NUM_INTRA_MODES]; // const kvz_config *cfg = &state->encoder_control->cfg; // const bool filter_boundary = !(cfg->lossless && cfg->implicit_rdpcm); // Temporary block arrays - kvz_pixel _preds[PARALLEL_BLKS * 32 * 32 + SIMD_ALIGNMENT]; + uvg_pixel _preds[PARALLEL_BLKS * 32 * 32 + SIMD_ALIGNMENT]; pred_buffer preds = ALIGNED_POINTER(_preds, SIMD_ALIGNMENT); - kvz_pixel _orig_block[32 * 32 + SIMD_ALIGNMENT]; - kvz_pixel *orig_block = ALIGNED_POINTER(_orig_block, SIMD_ALIGNMENT); + uvg_pixel _orig_block[32 * 32 + SIMD_ALIGNMENT]; + uvg_pixel *orig_block = ALIGNED_POINTER(_orig_block, SIMD_ALIGNMENT); // Store original block for SAD computation - kvz_pixels_blit(orig, orig_block, width, width, origstride, width); + uvg_pixels_blit(orig, orig_block, width, width, origstride, width); int8_t modes_selected = 0; // Note: get_cost and get_cost_dual may return negative costs. @@ -973,9 +987,9 @@ static int16_t search_intra_rough( int offset = 4; search_proxy.pred_cu.intra.mode = 0; - kvz_intra_predict(state, refs, &loc, COLOR_Y, preds[0], &search_proxy, NULL); + uvg_intra_predict(state, refs, &loc, COLOR_Y, preds[0], &search_proxy, NULL); search_proxy.pred_cu.intra.mode = 1; - kvz_intra_predict(state, refs, &loc, COLOR_Y, preds[1], &search_proxy, NULL); + uvg_intra_predict(state, refs, &loc, COLOR_Y, preds[1], &search_proxy, NULL); get_cost_dual(state, preds, orig_block, satd_dual_func, sad_dual_func, width, costs); mode_checked[0] = true; mode_checked[1] = true; @@ -1025,7 +1039,7 @@ static int16_t search_intra_rough( for (int i = 0; i < PARALLEL_BLKS; ++i) { if (mode + i * offset <= 66) { search_proxy.pred_cu.intra.mode = mode + i*offset; - kvz_intra_predict(state, refs, &loc, COLOR_Y, preds[i], &search_proxy, NULL); + uvg_intra_predict(state, refs, &loc, COLOR_Y, preds[i], &search_proxy, NULL); } } @@ -1097,7 +1111,7 @@ static int16_t search_intra_rough( for (int block = 0; block < PARALLEL_BLKS; ++block) { search_proxy.pred_cu.intra.mode = modes_to_check[block + i]; - kvz_intra_predict(state, refs, &loc, COLOR_Y, preds[block], &search_proxy, NULL); + uvg_intra_predict(state, refs, &loc, COLOR_Y, preds[block], &search_proxy, NULL); } @@ -1765,10 +1779,7 @@ void uvg_search_cu_intra( depth, number_of_modes_to_search, search_data, - lcu); - // Reset these - search_data[0].pred_cu.violates_mts_coeff_constraint = false; - search_data[0].pred_cu.mts_last_scan_pos = false; + lcu); } *mode_out = search_data[0]; }