diff --git a/src/search.c b/src/search.c index 4345ad75..385c4981 100644 --- a/src/search.c +++ b/src/search.c @@ -415,6 +415,7 @@ static double calc_mode_bits(const encoder_state_t *state, } +// TODO: replace usages of this by the kvz_sort_indices_by_cost function. /** * \brief Sort modes and costs to ascending order according to costs. */ @@ -439,6 +440,25 @@ void kvz_sort_modes(int8_t *__restrict modes, double *__restrict costs, uint8_t } +/** + * \brief Sort indices to ascending order according to costs. + */ +void kvz_sort_indices_by_cost(blk_stats_map_t *__restrict map) +{ + // Size of sorted arrays is expected to be "small". No need for faster algorithm. + for (uint8_t i = 1; i < map->size; ++i) { + const int8_t cur_idx = map->idx[i]; + const double cur_cost = map->stats[cur_idx].cost; + uint8_t j = i; + while (j > 0 && cur_cost < map->stats[map->idx[j - 1]].cost) { + map->idx[j] = map->idx[j - 1]; + --j; + } + map->idx[j] = cur_idx; + } +} + + static uint8_t get_ctx_cu_split_model(const lcu_t *lcu, int x, int y, int depth) { vector2d_t lcu_cu = { SUB_SCU(x), SUB_SCU(y) }; diff --git a/src/search.h b/src/search.h index 774a4d7b..fe6d7f5d 100644 --- a/src/search.h +++ b/src/search.h @@ -44,7 +44,22 @@ #include "image.h" #include "constraint.h" +typedef struct blk_stats_t { + + cu_info_t blk; // list of blocks + double cost; // list of RD costs + uint32_t bits; // list of bit costs +} blk_stats_t; + +typedef struct blk_stats_map_t { + + blk_stats_t *stats; // list of block statistics entries + int8_t *idx; // list of indices to block stats (to be sorted by costs) + int size; // number of active elements in the lists +} blk_stats_map_t; + void kvz_sort_modes(int8_t *__restrict modes, double *__restrict costs, uint8_t length); +void kvz_sort_indices_by_cost(blk_stats_map_t *__restrict map); void kvz_search_lcu(encoder_state_t *state, int x, int y, const yuv_t *hor_buf, const yuv_t *ver_buf); diff --git a/src/search_inter.c b/src/search_inter.c index f091c260..d561387a 100644 --- a/src/search_inter.c +++ b/src/search_inter.c @@ -1215,11 +1215,12 @@ static void apply_mv_scaling(int32_t current_poc, */ static void search_pu_inter_ref(inter_search_info_t *info, int depth, - lcu_t *lcu, cu_info_t *cur_cu, + lcu_t *lcu, + cu_info_t *cur_cu, double *inter_cost, uint32_t *inter_bitcost, double *best_LX_cost, - cu_info_t *unipred_LX) + blk_stats_map_t *amvp) { const kvz_config *cfg = &info->state->encoder_control->cfg; @@ -1409,15 +1410,23 @@ static void search_pu_inter_ref(inter_search_info_t *info, bool valid_mv = fracmv_within_tile(info, mv.x, mv.y); if (valid_mv) { // Map reference index to L0/L1 pictures - unipred_LX[ref_list].merged = false; - unipred_LX[ref_list].skipped = false; - unipred_LX[ref_list].inter.mv_dir = ref_list + 1; - unipred_LX[ref_list].inter.mv_ref[ref_list] = LX_idx; - unipred_LX[ref_list].inter.mv[ref_list][0] = (int16_t)mv.x; - unipred_LX[ref_list].inter.mv[ref_list][1] = (int16_t)mv.y; + blk_stats_map_t *cur_map = &amvp[ref_list]; + blk_stats_t *entry = &cur_map->stats[cur_map->size]; + cu_info_t *pb = &entry->blk; + pb->merged = false; + pb->skipped = false; + pb->inter.mv_dir = ref_list + 1; + pb->inter.mv_ref[ref_list] = LX_idx; + pb->inter.mv[ref_list][0] = (int16_t)mv.x; + pb->inter.mv[ref_list][1] = (int16_t)mv.y; - CU_SET_MV_CAND(&unipred_LX[ref_list], ref_list, cu_mv_cand); + CU_SET_MV_CAND(pb, ref_list, cu_mv_cand); + entry->cost = info->best_cost; + entry->bits = info->best_bitcost; + cur_map->size++; + + // TODO: remove (this is just to keep old functionality) best_LX_cost[ref_list] = info->best_cost; } } @@ -1669,6 +1678,7 @@ static void search_pu_inter(encoder_state_t * const state, mrg_costs[i] = MAX_DOUBLE; } + cu_info_t orig_cu = *cur_cu; int num_rdo_cands = 0; // Check motion vector constraints and perform rough search @@ -1765,16 +1775,31 @@ static void search_pu_inter(encoder_state_t * const state, // Store unipred information of L0 and L1 for biprediction // Best cost will be left at MAX_DOUBLE if no valid CU is found - double best_cost_LX[2] = { MAX_DOUBLE, MAX_DOUBLE }; - cu_info_t unipreds[2]; + double best_cost_LX[2] = { MAX_DOUBLE, MAX_DOUBLE }; // TODO: remove + blk_stats_t stats[2][MAX_REF_PIC_COUNT]; + int8_t idx[2][MAX_REF_PIC_COUNT]; + blk_stats_map_t amvp[2]; + + for (int ref_list = 0; ref_list < 2; ++ref_list) { + amvp[ref_list].stats = stats[ref_list]; + amvp[ref_list].idx = idx [ref_list]; + amvp[ref_list].size = 0; + for (int i = 0; i < MAX_REF_PIC_COUNT; ++i) { + amvp[ref_list].stats[i].blk = orig_cu; + amvp[ref_list].idx[i] = i; + } + } for (int ref_idx = 0; ref_idx < state->frame->ref->used_size; ref_idx++) { info.ref_idx = ref_idx; info.ref = state->frame->ref->images[ref_idx]; - search_pu_inter_ref(&info, depth, lcu, cur_cu, inter_cost, inter_bitcost, best_cost_LX, unipreds); + search_pu_inter_ref(&info, depth, lcu, cur_cu, inter_cost, inter_bitcost, best_cost_LX, amvp); } + kvz_sort_indices_by_cost(&amvp[0]); + kvz_sort_indices_by_cost(&amvp[1]); + // Search bi-pred positions bool can_use_bipred = state->frame->slicetype == KVZ_SLICE_B && cfg->bipred @@ -1792,15 +1817,21 @@ static void search_pu_inter(encoder_state_t * const state, inter_merge_cand_t *merge_cand = info.merge_cand; + int best_idx[2] = { amvp[0].idx[0], amvp[1].idx[0] }; + cu_info_t *best_unipred[2] = { + &amvp[0].stats[best_idx[0]].blk, + &amvp[1].stats[best_idx[1]].blk + }; + int16_t mv[2][2]; - mv[0][0] = unipreds[0].inter.mv[0][0]; - mv[0][1] = unipreds[0].inter.mv[0][1]; - mv[1][0] = unipreds[1].inter.mv[1][0]; - mv[1][1] = unipreds[1].inter.mv[1][1]; + mv[0][0] = best_unipred[0]->inter.mv[0][0]; + mv[0][1] = best_unipred[0]->inter.mv[0][1]; + mv[1][0] = best_unipred[1]->inter.mv[1][0]; + mv[1][1] = best_unipred[1]->inter.mv[1][1]; kvz_inter_recon_bipred(info.state, - ref->images[ref_LX[0][unipreds[0].inter.mv_ref[0]]], - ref->images[ref_LX[1][unipreds[1].inter.mv_ref[1]]], + ref->images[ref_LX[0][best_unipred[0]->inter.mv_ref[0]]], + ref->images[ref_LX[1][best_unipred[1]->inter.mv_ref[1]]], x, y, width, height, @@ -1817,23 +1848,23 @@ static void search_pu_inter(encoder_state_t * const state, uint32_t bitcost[2] = { 0, 0 }; cost += info.mvd_cost_func(info.state, - unipreds[0].inter.mv[0][0], - unipreds[0].inter.mv[0][1], + best_unipred[0]->inter.mv[0][0], + best_unipred[0]->inter.mv[0][1], 0, info.mv_cand, NULL, 0, 0, &bitcost[0]); cost += info.mvd_cost_func(info.state, - unipreds[1].inter.mv[1][0], - unipreds[1].inter.mv[1][1], + best_unipred[1]->inter.mv[1][0], + best_unipred[1]->inter.mv[1][1], 0, info.mv_cand, NULL, 0, 0, &bitcost[1]); const uint8_t mv_ref_coded[2] = { - unipreds[0].inter.mv_ref[0], - unipreds[1].inter.mv_ref[1] + best_unipred[0]->inter.mv_ref[0], + best_unipred[1]->inter.mv_ref[1] }; const int extra_bits = mv_ref_coded[0] + mv_ref_coded[1] + 2 /* mv dir cost */; cost += info.state->lambda_sqrt * extra_bits; @@ -1841,13 +1872,13 @@ static void search_pu_inter(encoder_state_t * const state, if (cost < *inter_cost) { cur_cu->inter.mv_dir = 3; - cur_cu->inter.mv_ref[0] = unipreds[0].inter.mv_ref[0]; - cur_cu->inter.mv_ref[1] = unipreds[1].inter.mv_ref[1]; + cur_cu->inter.mv_ref[0] = best_unipred[0]->inter.mv_ref[0]; + cur_cu->inter.mv_ref[1] = best_unipred[1]->inter.mv_ref[1]; - cur_cu->inter.mv[0][0] = unipreds[0].inter.mv[0][0]; - cur_cu->inter.mv[0][1] = unipreds[0].inter.mv[0][1]; - cur_cu->inter.mv[1][0] = unipreds[1].inter.mv[1][0]; - cur_cu->inter.mv[1][1] = unipreds[1].inter.mv[1][1]; + cur_cu->inter.mv[0][0] = best_unipred[0]->inter.mv[0][0]; + cur_cu->inter.mv[0][1] = best_unipred[0]->inter.mv[0][1]; + cur_cu->inter.mv[1][0] = best_unipred[1]->inter.mv[1][0]; + cur_cu->inter.mv[1][1] = best_unipred[1]->inter.mv[1][1]; cur_cu->merged = 0; // Check every candidate to find a match