diff --git a/src/cu.h b/src/cu.h index d3e38871..7b60805b 100644 --- a/src/cu.h +++ b/src/cu.h @@ -134,7 +134,6 @@ typedef struct int8_t tr_skip; //!< \brief transform skip flag } intra[4]; struct { - uint32_t bitcost; int16_t mv[2][2]; // \brief Motion vectors for L0 and L1 int16_t mvd[2][2]; // \brief Motion vector differences for L0 and L1 uint8_t mv_cand[2]; // \brief selected MV candidate diff --git a/src/search.c b/src/search.c index 1d5b6bec..e30a1ef5 100644 --- a/src/search.c +++ b/src/search.c @@ -446,22 +446,17 @@ static double calc_mode_bits(const encoder_state_t *state, const cu_info_t * cur_cu, int x, int y) { - double mode_bits; - - if (cur_cu->type == CU_INTER) { - mode_bits = cur_cu->inter.bitcost; - } else { - int8_t candidate_modes[3]; - { - const cu_info_t *left_cu = ((x >= 8) ? CU_GET_CU(cur_cu, -1, 0) : NULL); - const cu_info_t *above_cu = ((y >= 8) ? CU_GET_CU(cur_cu, 0, -1) : NULL); - kvz_intra_get_dir_luma_predictor(x, y, candidate_modes, cur_cu, left_cu, above_cu); - } + assert(cur_cu->type == CU_INTRA); + int8_t candidate_modes[3]; + { + const cu_info_t *left_cu = ((x >= 8) ? CU_GET_CU(cur_cu, -1, 0) : NULL); + const cu_info_t *above_cu = ((y >= 8) ? CU_GET_CU(cur_cu, 0, -1) : NULL); + kvz_intra_get_dir_luma_predictor(x, y, candidate_modes, cur_cu, left_cu, above_cu); + } - mode_bits = kvz_luma_mode_bits(state, cur_cu->intra[PU_INDEX(x >> 2, y >> 2)].mode, candidate_modes); - if (PU_INDEX(x >> 2, y >> 2) == 0) { - mode_bits += kvz_chroma_mode_bits(state, cur_cu->intra[0].mode_chroma, cur_cu->intra[PU_INDEX(x >> 2, y >> 2)].mode); - } + double mode_bits = kvz_luma_mode_bits(state, cur_cu->intra[PU_INDEX(x >> 2, y >> 2)].mode, candidate_modes); + if (PU_INDEX(x >> 2, y >> 2) == 0) { + mode_bits += kvz_chroma_mode_bits(state, cur_cu->intra[0].mode_chroma, cur_cu->intra[PU_INDEX(x >> 2, y >> 2)].mode); } return mode_bits; @@ -492,6 +487,7 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth, const videoframe_t * const frame = state->tile->frame; int cu_width = LCU_WIDTH >> depth; double cost = MAX_INT; + uint32_t inter_bitcost = MAX_INT; cu_info_t *cur_cu; lcu_t *const lcu = &work_tree[depth]; @@ -526,9 +522,16 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth, && WITHIN(depth, ctrl->pu_depth_inter.min, ctrl->pu_depth_inter.max); if (can_use_inter) { - double mode_cost = kvz_search_cu_inter(state, x, y, depth, &work_tree[depth]); + double mode_cost; + uint32_t mode_bitcost; + kvz_search_cu_inter(state, + x, y, + depth, + &work_tree[depth], + &mode_cost, &mode_bitcost); if (mode_cost < cost) { cost = mode_cost; + inter_bitcost = mode_bitcost; cur_cu->type = CU_INTER; } @@ -545,14 +548,16 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth, const int first_mode = ctrl->cfg->smp_enable ? 0 : 2; const int last_mode = (ctrl->cfg->amp_enable && cu_width >= 32) ? 5 : 1; for (int i = first_mode; i <= last_mode; ++i) { - mode_cost = kvz_search_cu_smp(state, - x, y, - depth, - mp_modes[i], - &work_tree[depth + 1]); + kvz_search_cu_smp(state, + x, y, + depth, + mp_modes[i], + &work_tree[depth + 1], + &mode_cost, &mode_bitcost); // TODO: take cost of coding part mode into account if (mode_cost < cost) { cost = mode_cost; + inter_bitcost = mode_bitcost; // TODO: only copy inter prediction info, not pixels work_tree_copy_up(x, y, depth, work_tree); } @@ -655,8 +660,8 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth, cur_cu->merged = 0; cur_cu->skipped = 1; // Selecting skip reduces bits needed to code the CU - if (cur_cu->inter.bitcost > 1) { - cur_cu->inter.bitcost -= 1; + if (inter_bitcost > 1) { + inter_bitcost -= 1; } } lcu_set_inter(&work_tree[depth], x, y, depth, cur_cu); @@ -666,7 +671,14 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth, if (cur_cu->type == CU_INTRA || cur_cu->type == CU_INTER) { cost = kvz_cu_rd_cost_luma(state, x_local, y_local, depth, cur_cu, &work_tree[depth]); cost += kvz_cu_rd_cost_chroma(state, x_local, y_local, depth, cur_cu, &work_tree[depth]); - double mode_bits = calc_mode_bits(state, cur_cu, x, y); + + double mode_bits; + if (cur_cu->type == CU_INTRA) { + mode_bits = calc_mode_bits(state, cur_cu, x, y); + } else { + mode_bits = inter_bitcost; + } + cost += mode_bits * state->global->cur_lambda_cost; } diff --git a/src/search_inter.c b/src/search_inter.c index c52e3c05..1de883a1 100644 --- a/src/search_inter.c +++ b/src/search_inter.c @@ -1093,7 +1093,8 @@ static void search_pu_inter_ref(encoder_state_t * const state, int16_t num_cand, unsigned ref_idx, uint32_t(*get_mvd_cost)(encoder_state_t * const, vector2d_t *, cabac_data_t*), - double *inter_cost) + double *inter_cost, + uint32_t *inter_bitcost) { const int x_cu = x >> 3; const int y_cu = y >> 3; @@ -1250,9 +1251,6 @@ static void search_pu_inter_ref(encoder_state_t * const state, mvd.y = mv.y - mv_cand[cu_mv_cand][1]; if (temp_cost < *inter_cost) { - - *inter_cost = temp_cost; - // Map reference index to L0/L1 pictures cur_cu->inter.mv_dir = ref_list+1; cur_cu->inter.mv_ref_coded[ref_list] = state->global->refmap[ref_idx].idx; @@ -1264,8 +1262,10 @@ static void search_pu_inter_ref(encoder_state_t * const state, cur_cu->inter.mv[ref_list][1] = (int16_t)mv.y; cur_cu->inter.mvd[ref_list][0] = (int16_t)mvd.x; cur_cu->inter.mvd[ref_list][1] = (int16_t)mvd.y; - cur_cu->inter.bitcost = temp_bitcost + cur_cu->inter.mv_dir - 1 + cur_cu->inter.mv_ref_coded[ref_list]; cur_cu->inter.mv_cand[ref_list] = cu_mv_cand; + + *inter_cost = temp_cost; + *inter_bitcost = temp_bitcost + cur_cu->inter.mv_dir - 1 + cur_cu->inter.mv_ref_coded[ref_list]; } } @@ -1281,15 +1281,21 @@ static void search_pu_inter_ref(encoder_state_t * const state, * \param i_pu index of the PU in the CU * \param lcu containing LCU * - * \return cost of the best mode + * \param inter_cost Return inter cost of the best mode + * \param inter_bitcost Return inter bitcost of the best mode */ -static int search_pu_inter(encoder_state_t * const state, - int x_cu, int y_cu, - int depth, - part_mode_t part_mode, - int i_pu, - lcu_t *lcu) +static void search_pu_inter(encoder_state_t * const state, + int x_cu, int y_cu, + int depth, + part_mode_t part_mode, + int i_pu, + lcu_t *lcu, + double *inter_cost, + uint32_t *inter_bitcost) { + *inter_cost = MAX_INT; + *inter_bitcost = MAX_INT; + const videoframe_t * const frame = state->tile->frame; const int width_cu = LCU_WIDTH >> depth; const int x = PU_GET_X(part_mode, width_cu, x_cu, i_pu); @@ -1341,8 +1347,6 @@ static int search_pu_inter(encoder_state_t * const state, cur_cu->inter.mv_cand[0] = 0; cur_cu->inter.mv_cand[1] = 0; - double inter_cost = INT_MAX; - uint32_t ref_idx; for (ref_idx = 0; ref_idx < state->global->ref->used_size; ref_idx++) { search_pu_inter_ref(state, @@ -1353,7 +1357,8 @@ static int search_pu_inter(encoder_state_t * const state, mv_cand, merge_cand, num_cand, ref_idx, get_mvd_cost, - &inter_cost); + inter_cost, + inter_bitcost); } // Search bi-pred positions @@ -1430,7 +1435,7 @@ static int search_pu_inter(encoder_state_t * const state, cost += calc_mvd(state, merge_cand[i].mv[0][0], merge_cand[i].mv[0][1], 0, mv_cand, merge_cand, 0, ref_idx, &bitcost[0]); cost += calc_mvd(state, merge_cand[i].mv[1][0], merge_cand[i].mv[1][1], 0, mv_cand, merge_cand, 0, ref_idx, &bitcost[1]); - if (cost < inter_cost) { + if (cost < *inter_cost) { cur_cu->inter.mv_dir = 3; cur_cu->inter.mv_ref_coded[0] = state->global->refmap[merge_cand[i].ref[0]].idx; @@ -1487,8 +1492,8 @@ static int search_pu_inter(encoder_state_t * const state, cur_cu->inter.mvd[reflist][1] = cur_cu->inter.mv[reflist][1] - mv_cand[cu_mv_cand][1]; cur_cu->inter.mv_cand[reflist] = cu_mv_cand; } - inter_cost = cost; - cur_cu->inter.bitcost = bitcost[0] + bitcost[1] + cur_cu->inter.mv_dir - 1 + cur_cu->inter.mv_ref_coded[0] + cur_cu->inter.mv_ref_coded[1]; + *inter_cost = cost; + *inter_bitcost = bitcost[0] + bitcost[1] + cur_cu->inter.mv_dir - 1 + cur_cu->inter.mv_ref_coded[0] + cur_cu->inter.mv_ref_coded[1]; } } } @@ -1496,14 +1501,12 @@ static int search_pu_inter(encoder_state_t * const state, FREE_POINTER(templcu); } - if (inter_cost < INT_MAX) { + if (*inter_cost < INT_MAX) { const vector2d_t orig = { x, y }; if (cur_cu->inter.mv_dir == 1) { assert(fracmv_within_tile(state, &orig, cur_cu->inter.mv[0][0], cur_cu->inter.mv[0][1], width, height, -1)); } } - - return inter_cost; } @@ -1518,11 +1521,21 @@ static int search_pu_inter(encoder_state_t * const state, * \param depth depth of the CU in the quadtree * \param lcu containing LCU * - * \return cost of the best mode + * \param inter_cost Return inter cost + * \param inter_bitcost Return inter bitcost */ -int kvz_search_cu_inter(encoder_state_t * const state, int x, int y, int depth, lcu_t *lcu) +void kvz_search_cu_inter(encoder_state_t * const state, + int x, int y, int depth, + lcu_t *lcu, + double *inter_cost, + uint32_t *inter_bitcost) { - return search_pu_inter(state, x, y, depth, SIZE_2Nx2N, 0, lcu); + search_pu_inter(state, + x, y, depth, + SIZE_2Nx2N, 0, + lcu, + inter_cost, + inter_bitcost); } @@ -1538,20 +1551,25 @@ int kvz_search_cu_inter(encoder_state_t * const state, int x, int y, int depth, * \param part_mode partition mode to search * \param lcu containing LCU * - * \return cost of the best mode + * \param inter_cost Return inter cost + * \param inter_bitcost Return inter bitcost */ -int kvz_search_cu_smp(encoder_state_t * const state, - int x, int y, - int depth, - part_mode_t part_mode, - lcu_t *lcu) +void kvz_search_cu_smp(encoder_state_t * const state, + int x, int y, + int depth, + part_mode_t part_mode, + lcu_t *lcu, + double *inter_cost, + uint32_t *inter_bitcost) { const int num_pu = kvz_part_mode_num_parts[part_mode]; const int width_scu = (LCU_WIDTH >> depth) >> MAX_DEPTH; const int y_scu = SUB_SCU(y) >> MAX_DEPTH; const int x_scu = SUB_SCU(x) >> MAX_DEPTH; - int cost = 0; + *inter_cost = 0; + *inter_bitcost = 0; + for (int i = 0; i < num_pu; ++i) { const int x_pu = PU_GET_X(part_mode, width_scu, x_scu, i); const int y_pu = PU_GET_Y(part_mode, width_scu, y_scu, i); @@ -1559,11 +1577,17 @@ int kvz_search_cu_smp(encoder_state_t * const state, const int height_pu = PU_GET_H(part_mode, width_scu, i); cu_info_t *cur_pu = LCU_GET_CU(lcu, x_pu, y_pu); - cur_pu->type = CU_INTER; + cur_pu->type = CU_INTER; cur_pu->part_size = part_mode; - cur_pu->depth = depth; + cur_pu->depth = depth; - cost += search_pu_inter(state, x, y, depth, part_mode, i, lcu); + double cost = MAX_INT; + uint32_t bitcost = MAX_INT; + + search_pu_inter(state, x, y, depth, part_mode, i, lcu, &cost, &bitcost); + + *inter_cost += cost; + *inter_bitcost += bitcost; for (int y = y_pu; y < y_pu + height_pu; ++y) { for (int x = x_pu; x < x_pu + width_pu; ++x) { @@ -1573,6 +1597,4 @@ int kvz_search_cu_smp(encoder_state_t * const state, } } } - - return cost; } diff --git a/src/search_inter.h b/src/search_inter.h index 780145e4..1b7c7632 100644 --- a/src/search_inter.h +++ b/src/search_inter.h @@ -31,12 +31,18 @@ #include "global.h" // IWYU pragma: keep -int kvz_search_cu_inter(encoder_state_t * const state, int x, int y, int depth, lcu_t *lcu); +void kvz_search_cu_inter(encoder_state_t * const state, + int x, int y, int depth, + lcu_t *lcu, + double *inter_cost, + uint32_t *inter_bitcost); -int kvz_search_cu_smp(encoder_state_t * const state, - int x, int y, - int depth, - part_mode_t part_mode, - lcu_t *lcu); +void kvz_search_cu_smp(encoder_state_t * const state, + int x, int y, + int depth, + part_mode_t part_mode, + lcu_t *lcu, + double *inter_cost, + uint32_t *inter_bitcost); #endif // SEARCH_INTER_H_