Add new structs for storing statistics during the search. Use in AMVP search.

This commit is contained in:
Ari Lemmetti 2021-11-28 23:40:16 +02:00
parent 936fb76685
commit 90c0a708a7
3 changed files with 96 additions and 30 deletions

View file

@ -415,6 +415,7 @@ static double calc_mode_bits(const encoder_state_t *state,
} }
// TODO: replace usages of this by the kvz_sort_indices_by_cost function.
/** /**
* \brief Sort modes and costs to ascending order according to costs. * \brief Sort modes and costs to ascending order according to costs.
*/ */
@ -439,6 +440,25 @@ void kvz_sort_modes(int8_t *__restrict modes, double *__restrict costs, uint8_t
} }
/**
 * \brief Sort indices to ascending order according to costs.
 *
 * Reorders the first map->size entries of map->idx so that the costs of the
 * map->stats entries they refer to are in ascending order. The stats entries
 * themselves are not moved; only the index list is permuted. Indices whose
 * entries have equal cost keep their relative order.
 *
 * \param map  Statistics map whose index list is sorted in place.
 */
void kvz_sort_indices_by_cost(blk_stats_map_t *__restrict map)
{
  // Size of sorted arrays is expected to be "small". No need for faster algorithm.
  // The counters are int to match the type of map->size: a narrower unsigned
  // counter would wrap around and never reach the end if size exceeded its range.
  for (int i = 1; i < map->size; ++i) {
    const int8_t cur_idx = map->idx[i];
    const double cur_cost = map->stats[cur_idx].cost;

    // Shift every index with a strictly larger cost one slot towards the end,
    // then drop the current index into the gap that opens up.
    int j = i;
    while (j > 0 && cur_cost < map->stats[map->idx[j - 1]].cost) {
      map->idx[j] = map->idx[j - 1];
      --j;
    }
    map->idx[j] = cur_idx;
  }
}
static uint8_t get_ctx_cu_split_model(const lcu_t *lcu, int x, int y, int depth) static uint8_t get_ctx_cu_split_model(const lcu_t *lcu, int x, int y, int depth)
{ {
vector2d_t lcu_cu = { SUB_SCU(x), SUB_SCU(y) }; vector2d_t lcu_cu = { SUB_SCU(x), SUB_SCU(y) };

View file

@ -44,7 +44,22 @@
#include "image.h" #include "image.h"
#include "constraint.h" #include "constraint.h"
/**
 * \brief Statistics of a single block candidate collected during the search.
 */
typedef struct blk_stats_t {
  cu_info_t blk; // block info of this candidate
  double cost; // RD cost of this candidate
  uint32_t bits; // bit cost of this candidate
} blk_stats_t;
/**
 * \brief A list of block statistics together with a list of indices into it.
 *
 * The index list can be ordered by cost with kvz_sort_indices_by_cost, which
 * permutes idx and leaves the stats entries in place.
 */
typedef struct blk_stats_map_t {
  blk_stats_t *stats; // list of block statistics entries
  int8_t *idx; // list of indices to block stats (to be sorted by costs)
  int size; // number of active elements in the lists
} blk_stats_map_t;
void kvz_sort_modes(int8_t *__restrict modes, double *__restrict costs, uint8_t length); void kvz_sort_modes(int8_t *__restrict modes, double *__restrict costs, uint8_t length);
void kvz_sort_indices_by_cost(blk_stats_map_t *__restrict map);
void kvz_search_lcu(encoder_state_t *state, int x, int y, const yuv_t *hor_buf, const yuv_t *ver_buf); void kvz_search_lcu(encoder_state_t *state, int x, int y, const yuv_t *hor_buf, const yuv_t *ver_buf);

View file

@ -1215,11 +1215,12 @@ static void apply_mv_scaling(int32_t current_poc,
*/ */
static void search_pu_inter_ref(inter_search_info_t *info, static void search_pu_inter_ref(inter_search_info_t *info,
int depth, int depth,
lcu_t *lcu, cu_info_t *cur_cu, lcu_t *lcu,
cu_info_t *cur_cu,
double *inter_cost, double *inter_cost,
uint32_t *inter_bitcost, uint32_t *inter_bitcost,
double *best_LX_cost, double *best_LX_cost,
cu_info_t *unipred_LX) blk_stats_map_t *amvp)
{ {
const kvz_config *cfg = &info->state->encoder_control->cfg; const kvz_config *cfg = &info->state->encoder_control->cfg;
@ -1409,15 +1410,23 @@ static void search_pu_inter_ref(inter_search_info_t *info,
bool valid_mv = fracmv_within_tile(info, mv.x, mv.y); bool valid_mv = fracmv_within_tile(info, mv.x, mv.y);
if (valid_mv) { if (valid_mv) {
// Map reference index to L0/L1 pictures // Map reference index to L0/L1 pictures
unipred_LX[ref_list].merged = false; blk_stats_map_t *cur_map = &amvp[ref_list];
unipred_LX[ref_list].skipped = false; blk_stats_t *entry = &cur_map->stats[cur_map->size];
unipred_LX[ref_list].inter.mv_dir = ref_list + 1; cu_info_t *pb = &entry->blk;
unipred_LX[ref_list].inter.mv_ref[ref_list] = LX_idx; pb->merged = false;
unipred_LX[ref_list].inter.mv[ref_list][0] = (int16_t)mv.x; pb->skipped = false;
unipred_LX[ref_list].inter.mv[ref_list][1] = (int16_t)mv.y; pb->inter.mv_dir = ref_list + 1;
pb->inter.mv_ref[ref_list] = LX_idx;
pb->inter.mv[ref_list][0] = (int16_t)mv.x;
pb->inter.mv[ref_list][1] = (int16_t)mv.y;
CU_SET_MV_CAND(&unipred_LX[ref_list], ref_list, cu_mv_cand); CU_SET_MV_CAND(pb, ref_list, cu_mv_cand);
entry->cost = info->best_cost;
entry->bits = info->best_bitcost;
cur_map->size++;
// TODO: remove (this is just to keep old functionality)
best_LX_cost[ref_list] = info->best_cost; best_LX_cost[ref_list] = info->best_cost;
} }
} }
@ -1669,6 +1678,7 @@ static void search_pu_inter(encoder_state_t * const state,
mrg_costs[i] = MAX_DOUBLE; mrg_costs[i] = MAX_DOUBLE;
} }
cu_info_t orig_cu = *cur_cu;
int num_rdo_cands = 0; int num_rdo_cands = 0;
// Check motion vector constraints and perform rough search // Check motion vector constraints and perform rough search
@ -1765,16 +1775,31 @@ static void search_pu_inter(encoder_state_t * const state,
// Store unipred information of L0 and L1 for biprediction // Store unipred information of L0 and L1 for biprediction
// Best cost will be left at MAX_DOUBLE if no valid CU is found // Best cost will be left at MAX_DOUBLE if no valid CU is found
double best_cost_LX[2] = { MAX_DOUBLE, MAX_DOUBLE }; double best_cost_LX[2] = { MAX_DOUBLE, MAX_DOUBLE }; // TODO: remove
cu_info_t unipreds[2]; blk_stats_t stats[2][MAX_REF_PIC_COUNT];
int8_t idx[2][MAX_REF_PIC_COUNT];
blk_stats_map_t amvp[2];
for (int ref_list = 0; ref_list < 2; ++ref_list) {
amvp[ref_list].stats = stats[ref_list];
amvp[ref_list].idx = idx [ref_list];
amvp[ref_list].size = 0;
for (int i = 0; i < MAX_REF_PIC_COUNT; ++i) {
amvp[ref_list].stats[i].blk = orig_cu;
amvp[ref_list].idx[i] = i;
}
}
for (int ref_idx = 0; ref_idx < state->frame->ref->used_size; ref_idx++) { for (int ref_idx = 0; ref_idx < state->frame->ref->used_size; ref_idx++) {
info.ref_idx = ref_idx; info.ref_idx = ref_idx;
info.ref = state->frame->ref->images[ref_idx]; info.ref = state->frame->ref->images[ref_idx];
search_pu_inter_ref(&info, depth, lcu, cur_cu, inter_cost, inter_bitcost, best_cost_LX, unipreds); search_pu_inter_ref(&info, depth, lcu, cur_cu, inter_cost, inter_bitcost, best_cost_LX, amvp);
} }
kvz_sort_indices_by_cost(&amvp[0]);
kvz_sort_indices_by_cost(&amvp[1]);
// Search bi-pred positions // Search bi-pred positions
bool can_use_bipred = state->frame->slicetype == KVZ_SLICE_B bool can_use_bipred = state->frame->slicetype == KVZ_SLICE_B
&& cfg->bipred && cfg->bipred
@ -1792,15 +1817,21 @@ static void search_pu_inter(encoder_state_t * const state,
inter_merge_cand_t *merge_cand = info.merge_cand; inter_merge_cand_t *merge_cand = info.merge_cand;
int best_idx[2] = { amvp[0].idx[0], amvp[1].idx[0] };
cu_info_t *best_unipred[2] = {
&amvp[0].stats[best_idx[0]].blk,
&amvp[1].stats[best_idx[1]].blk
};
int16_t mv[2][2]; int16_t mv[2][2];
mv[0][0] = unipreds[0].inter.mv[0][0]; mv[0][0] = best_unipred[0]->inter.mv[0][0];
mv[0][1] = unipreds[0].inter.mv[0][1]; mv[0][1] = best_unipred[0]->inter.mv[0][1];
mv[1][0] = unipreds[1].inter.mv[1][0]; mv[1][0] = best_unipred[1]->inter.mv[1][0];
mv[1][1] = unipreds[1].inter.mv[1][1]; mv[1][1] = best_unipred[1]->inter.mv[1][1];
kvz_inter_recon_bipred(info.state, kvz_inter_recon_bipred(info.state,
ref->images[ref_LX[0][unipreds[0].inter.mv_ref[0]]], ref->images[ref_LX[0][best_unipred[0]->inter.mv_ref[0]]],
ref->images[ref_LX[1][unipreds[1].inter.mv_ref[1]]], ref->images[ref_LX[1][best_unipred[1]->inter.mv_ref[1]]],
x, y, x, y,
width, width,
height, height,
@ -1817,23 +1848,23 @@ static void search_pu_inter(encoder_state_t * const state,
uint32_t bitcost[2] = { 0, 0 }; uint32_t bitcost[2] = { 0, 0 };
cost += info.mvd_cost_func(info.state, cost += info.mvd_cost_func(info.state,
unipreds[0].inter.mv[0][0], best_unipred[0]->inter.mv[0][0],
unipreds[0].inter.mv[0][1], best_unipred[0]->inter.mv[0][1],
0, 0,
info.mv_cand, info.mv_cand,
NULL, 0, 0, NULL, 0, 0,
&bitcost[0]); &bitcost[0]);
cost += info.mvd_cost_func(info.state, cost += info.mvd_cost_func(info.state,
unipreds[1].inter.mv[1][0], best_unipred[1]->inter.mv[1][0],
unipreds[1].inter.mv[1][1], best_unipred[1]->inter.mv[1][1],
0, 0,
info.mv_cand, info.mv_cand,
NULL, 0, 0, NULL, 0, 0,
&bitcost[1]); &bitcost[1]);
const uint8_t mv_ref_coded[2] = { const uint8_t mv_ref_coded[2] = {
unipreds[0].inter.mv_ref[0], best_unipred[0]->inter.mv_ref[0],
unipreds[1].inter.mv_ref[1] best_unipred[1]->inter.mv_ref[1]
}; };
const int extra_bits = mv_ref_coded[0] + mv_ref_coded[1] + 2 /* mv dir cost */; const int extra_bits = mv_ref_coded[0] + mv_ref_coded[1] + 2 /* mv dir cost */;
cost += info.state->lambda_sqrt * extra_bits; cost += info.state->lambda_sqrt * extra_bits;
@ -1841,13 +1872,13 @@ static void search_pu_inter(encoder_state_t * const state,
if (cost < *inter_cost) { if (cost < *inter_cost) {
cur_cu->inter.mv_dir = 3; cur_cu->inter.mv_dir = 3;
cur_cu->inter.mv_ref[0] = unipreds[0].inter.mv_ref[0]; cur_cu->inter.mv_ref[0] = best_unipred[0]->inter.mv_ref[0];
cur_cu->inter.mv_ref[1] = unipreds[1].inter.mv_ref[1]; cur_cu->inter.mv_ref[1] = best_unipred[1]->inter.mv_ref[1];
cur_cu->inter.mv[0][0] = unipreds[0].inter.mv[0][0]; cur_cu->inter.mv[0][0] = best_unipred[0]->inter.mv[0][0];
cur_cu->inter.mv[0][1] = unipreds[0].inter.mv[0][1]; cur_cu->inter.mv[0][1] = best_unipred[0]->inter.mv[0][1];
cur_cu->inter.mv[1][0] = unipreds[1].inter.mv[1][0]; cur_cu->inter.mv[1][0] = best_unipred[1]->inter.mv[1][0];
cur_cu->inter.mv[1][1] = unipreds[1].inter.mv[1][1]; cur_cu->inter.mv[1][1] = best_unipred[1]->inter.mv[1][1];
cur_cu->merged = 0; cur_cu->merged = 0;
// Check every candidate to find a match // Check every candidate to find a match