Refactor inter MV candidate selection

Moves duplicate code for checking the best MV candidate from functions
calc_mvd_cost, search_pu_inter_ref and search_pu_inter to a new
function.
This commit is contained in:
Arttu Ylä-Outinen 2018-01-18 12:47:27 +02:00
parent 774c666528
commit c1cca1ad7f
3 changed files with 86 additions and 91 deletions

View file

@ -852,17 +852,15 @@ void kvz_rdoq(encoder_state_t * const state, coeff_t *coef, coeff_t *dest_coeff,
} }
} }
/** MVD cost calculation with CABAC /**
* \returns int * Calculate cost of actual motion vectors using CABAC coding
* Calculates cost of actual motion vectors using CABAC coding */
*/
uint32_t kvz_get_mvd_coding_cost_cabac(const encoder_state_t *state, uint32_t kvz_get_mvd_coding_cost_cabac(const encoder_state_t *state,
vector2d_t *mvd, const cabac_data_t* real_cabac,
const cabac_data_t* real_cabac) int32_t mvd_hor,
int32_t mvd_ver)
{ {
uint32_t bitcost = 0; uint32_t bitcost = 0;
const int32_t mvd_hor = mvd->x;
const int32_t mvd_ver = mvd->y;
const int8_t hor_abs_gr0 = mvd_hor != 0; const int8_t hor_abs_gr0 = mvd_hor != 0;
const int8_t ver_abs_gr0 = mvd_ver != 0; const int8_t ver_abs_gr0 = mvd_ver != 0;
const uint32_t mvd_hor_abs = abs(mvd_hor); const uint32_t mvd_hor_abs = abs(mvd_hor);
@ -919,8 +917,7 @@ uint32_t kvz_calc_mvd_cost_cabac(const encoder_state_t * state,
cabac_data_t state_cabac_copy; cabac_data_t state_cabac_copy;
cabac_data_t* cabac; cabac_data_t* cabac;
uint32_t merge_idx; uint32_t merge_idx;
int cand1_cost, cand2_cost; vector2d_t mvd = { 0, 0 };
vector2d_t mvd_temp1, mvd_temp2, mvd = { 0, 0 };
int8_t merged = 0; int8_t merged = 0;
int8_t cur_mv_cand = 0; int8_t cur_mv_cand = 0;
@ -952,20 +949,23 @@ uint32_t kvz_calc_mvd_cost_cabac(const encoder_state_t * state,
cabac = &state_cabac_copy; cabac = &state_cabac_copy;
if (!merged) { if (!merged) {
mvd_temp1.x = x - mv_cand[0][0]; vector2d_t mvd1 = {
mvd_temp1.y = y - mv_cand[0][1]; x - mv_cand[0][0],
cand1_cost = kvz_get_mvd_coding_cost_cabac(state, &mvd_temp1, cabac); y - mv_cand[0][1],
};
mvd_temp2.x = x - mv_cand[1][0]; vector2d_t mvd2 = {
mvd_temp2.y = y - mv_cand[1][1]; x - mv_cand[1][0],
cand2_cost = kvz_get_mvd_coding_cost_cabac(state, &mvd_temp2, cabac); y - mv_cand[1][1],
};
uint32_t cand1_cost = kvz_get_mvd_coding_cost_cabac(state, cabac, mvd1.x, mvd1.y);
uint32_t cand2_cost = kvz_get_mvd_coding_cost_cabac(state, cabac, mvd2.x, mvd2.y);
// Select candidate 1 if it has lower cost // Select candidate 1 if it has lower cost
if (cand2_cost < cand1_cost) { if (cand2_cost < cand1_cost) {
cur_mv_cand = 1; cur_mv_cand = 1;
mvd = mvd_temp2; mvd = mvd2;
} else { } else {
mvd = mvd_temp1; mvd = mvd1;
} }
} }

View file

@ -58,8 +58,9 @@ uint32_t kvz_get_coded_level(encoder_state_t * state, double* coded_cost, double
kvz_mvd_cost_func kvz_calc_mvd_cost_cabac; kvz_mvd_cost_func kvz_calc_mvd_cost_cabac;
uint32_t kvz_get_mvd_coding_cost_cabac(const encoder_state_t *state, uint32_t kvz_get_mvd_coding_cost_cabac(const encoder_state_t *state,
vector2d_t *mvd, const cabac_data_t* cabac,
const cabac_data_t* cabac); int32_t mvd_hor,
int32_t mvd_ver);
// Number of fixed point fractional bits used in the fractional bit table. // Number of fixed point fractional bits used in the fractional bit table.
#define CTX_FRAC_BITS 15 #define CTX_FRAC_BITS 15

View file

@ -307,11 +307,12 @@ static void select_starting_point(inter_search_info_t *info, vector2d_t extra_mv
static uint32_t get_mvd_coding_cost(const encoder_state_t *state, static uint32_t get_mvd_coding_cost(const encoder_state_t *state,
vector2d_t *mvd, const cabac_data_t* cabac,
const cabac_data_t* cabac) const int32_t mvd_hor,
const int32_t mvd_ver)
{ {
unsigned bitcost = 0; unsigned bitcost = 0;
const vector2d_t abs_mvd = { abs(mvd->x), abs(mvd->y) }; const vector2d_t abs_mvd = { abs(mvd_hor), abs(mvd_ver) };
bitcost += CTX_ENTROPY_BITS(&cabac->ctx.cu_mvd_model[0], abs_mvd.x > 0); bitcost += CTX_ENTROPY_BITS(&cabac->ctx.cu_mvd_model[0], abs_mvd.x > 0);
if (abs_mvd.x > 0) { if (abs_mvd.x > 0) {
@ -336,6 +337,53 @@ static uint32_t get_mvd_coding_cost(const encoder_state_t *state,
} }
static int select_mv_cand(const encoder_state_t *state,
int16_t mv_cand[2][2],
int32_t mv_x,
int32_t mv_y,
uint32_t *cost_out)
{
const bool same_cand =
(mv_cand[0][0] == mv_cand[1][0] && mv_cand[0][1] == mv_cand[1][1]);
if (same_cand && !cost_out) {
// Pick the first one if both candidates are the same.
return 0;
}
uint32_t (*mvd_coding_cost)(const encoder_state_t * const state,
const cabac_data_t*,
int32_t, int32_t);
if (state->encoder_control->cfg.mv_rdo) {
mvd_coding_cost = kvz_get_mvd_coding_cost_cabac;
} else {
mvd_coding_cost = get_mvd_coding_cost;
}
uint32_t cand1_cost = mvd_coding_cost(
state, &state->cabac,
mv_x - mv_cand[0][0],
mv_y - mv_cand[0][1]);
uint32_t cand2_cost;
if (same_cand) {
cand2_cost = cand1_cost;
} else {
cand2_cost = mvd_coding_cost(
state, &state->cabac,
mv_x - mv_cand[1][0],
mv_y - mv_cand[1][1]);
}
if (cost_out) {
*cost_out = MIN(cand1_cost, cand2_cost);
}
// Pick the second candidate if it has lower cost.
return cand2_cost < cand1_cost ? 1 : 0;
}
static uint32_t calc_mvd_cost(const encoder_state_t *state, static uint32_t calc_mvd_cost(const encoder_state_t *state,
int x, int x,
int y, int y,
@ -348,10 +396,7 @@ static uint32_t calc_mvd_cost(const encoder_state_t *state,
{ {
uint32_t temp_bitcost = 0; uint32_t temp_bitcost = 0;
uint32_t merge_idx; uint32_t merge_idx;
int cand1_cost,cand2_cost;
vector2d_t mvd_temp1, mvd_temp2;
int8_t merged = 0; int8_t merged = 0;
int8_t cur_mv_cand = 0;
x *= 1 << mv_shift; x *= 1 << mv_shift;
y *= 1 << mv_shift; y *= 1 << mv_shift;
@ -371,20 +416,10 @@ static uint32_t calc_mvd_cost(const encoder_state_t *state,
} }
// Check mvd cost only if mv is not merged // Check mvd cost only if mv is not merged
if(!merged) { if (!merged) {
mvd_temp1.x = x - mv_cand[0][0]; uint32_t mvd_cost = 0;
mvd_temp1.y = y - mv_cand[0][1]; select_mv_cand(state, mv_cand, x, y, &mvd_cost);
cand1_cost = get_mvd_coding_cost(state, &mvd_temp1, &state->cabac); temp_bitcost += mvd_cost;
mvd_temp2.x = x - mv_cand[1][0];
mvd_temp2.y = y - mv_cand[1][1];
cand2_cost = get_mvd_coding_cost(state, &mvd_temp2, &state->cabac);
// Select candidate 1 if it has lower cost
if (cand2_cost < cand1_cost) {
cur_mv_cand = 1;
}
temp_bitcost += cur_mv_cand ? cand2_cost : cand1_cost;
} }
*bitcost = temp_bitcost; *bitcost = temp_bitcost;
return temp_bitcost*(int32_t)(state->lambda_sqrt + 0.5); return temp_bitcost*(int32_t)(state->lambda_sqrt + 0.5);
@ -1315,30 +1350,9 @@ static void search_pu_inter_ref(inter_search_info_t *info,
// Only check when candidates are different // Only check when candidates are different
int cu_mv_cand = 0; int cu_mv_cand = 0;
if (!merged && ( if (!merged) {
info->mv_cand[0][0] != info->mv_cand[1][0] || cu_mv_cand =
info->mv_cand[0][1] != info->mv_cand[1][1])) select_mv_cand(info->state, info->mv_cand, mv.x, mv.y, NULL);
{
uint32_t (*mvd_coding_cost)(const encoder_state_t * const state,
vector2d_t *,
const cabac_data_t*) =
cfg->mv_rdo ? kvz_get_mvd_coding_cost_cabac : get_mvd_coding_cost;
vector2d_t mvd_temp1, mvd_temp2;
int cand1_cost,cand2_cost;
mvd_temp1.x = mv.x - info->mv_cand[0][0];
mvd_temp1.y = mv.y - info->mv_cand[0][1];
cand1_cost = mvd_coding_cost(info->state, &mvd_temp1, &info->state->cabac);
mvd_temp2.x = mv.x - info->mv_cand[1][0];
mvd_temp2.y = mv.y - info->mv_cand[1][1];
cand2_cost = mvd_coding_cost(info->state, &mvd_temp2, &info->state->cabac);
// Select candidate 1 if it has lower cost
if (cand2_cost < cand1_cost) {
cu_mv_cand = 1;
}
} }
if (info->best_cost < *inter_cost) { if (info->best_cost < *inter_cost) {
@ -1465,7 +1479,6 @@ static void search_pu_inter(encoder_state_t * const state,
{ {
uint32_t bitcost[2]; uint32_t bitcost[2];
uint32_t cost = 0; uint32_t cost = 0;
int8_t cu_mv_cand = 0;
int16_t mv[2][2]; int16_t mv[2][2];
kvz_pixel tmp_block[64 * 64]; kvz_pixel tmp_block[64 * 64];
kvz_pixel tmp_pic[64 * 64]; kvz_pixel tmp_pic[64 * 64];
@ -1558,32 +1571,13 @@ static void search_pu_inter(encoder_state_t * const state,
// Each motion vector has its own candidate // Each motion vector has its own candidate
for (int reflist = 0; reflist < 2; reflist++) { for (int reflist = 0; reflist < 2; reflist++) {
cu_mv_cand = 0;
kvz_inter_get_mv_cand(state, x, y, width, height, info.mv_cand, cur_cu, lcu, reflist); kvz_inter_get_mv_cand(state, x, y, width, height, info.mv_cand, cur_cu, lcu, reflist);
if (info.mv_cand[0][0] != info.mv_cand[1][0] || int cu_mv_cand = select_mv_cand(
info.mv_cand[0][1] != info.mv_cand[1][1]) state,
{ info.mv_cand,
uint32_t (*mvd_coding_cost)(const encoder_state_t * const state, cur_cu->inter.mv[reflist][0],
vector2d_t *, cur_cu->inter.mv[reflist][1],
const cabac_data_t*) = NULL);
cfg->mv_rdo ? kvz_get_mvd_coding_cost_cabac : get_mvd_coding_cost;
vector2d_t mvd_temp1, mvd_temp2;
int cand1_cost, cand2_cost;
mvd_temp1.x = cur_cu->inter.mv[reflist][0] - info.mv_cand[0][0];
mvd_temp1.y = cur_cu->inter.mv[reflist][1] - info.mv_cand[0][1];
cand1_cost = mvd_coding_cost(state, &mvd_temp1, (cabac_data_t*)&state->cabac);
mvd_temp2.x = cur_cu->inter.mv[reflist][0] - info.mv_cand[1][0];
mvd_temp2.y = cur_cu->inter.mv[reflist][1] - info.mv_cand[1][1];
cand2_cost = mvd_coding_cost(state, &mvd_temp2, (cabac_data_t*)&state->cabac);
// Select candidate 1 if it has lower cost
if (cand2_cost < cand1_cost) {
cu_mv_cand = 1;
}
}
CU_SET_MV_CAND(cur_cu, reflist, cu_mv_cand); CU_SET_MV_CAND(cur_cu, reflist, cu_mv_cand);
} }