diff --git a/src/rdo.c b/src/rdo.c index 31e1fdd8..2dd1dd2e 100644 --- a/src/rdo.c +++ b/src/rdo.c @@ -852,17 +852,15 @@ void kvz_rdoq(encoder_state_t * const state, coeff_t *coef, coeff_t *dest_coeff, } } -/** MVD cost calculation with CABAC -* \returns int -* Calculates cost of actual motion vectors using CABAC coding -*/ +/** + * Calculate cost of actual motion vectors using CABAC coding + */ uint32_t kvz_get_mvd_coding_cost_cabac(const encoder_state_t *state, - vector2d_t *mvd, - const cabac_data_t* real_cabac) + const cabac_data_t* real_cabac, + int32_t mvd_hor, + int32_t mvd_ver) { uint32_t bitcost = 0; - const int32_t mvd_hor = mvd->x; - const int32_t mvd_ver = mvd->y; const int8_t hor_abs_gr0 = mvd_hor != 0; const int8_t ver_abs_gr0 = mvd_ver != 0; const uint32_t mvd_hor_abs = abs(mvd_hor); @@ -919,8 +917,7 @@ uint32_t kvz_calc_mvd_cost_cabac(const encoder_state_t * state, cabac_data_t state_cabac_copy; cabac_data_t* cabac; uint32_t merge_idx; - int cand1_cost, cand2_cost; - vector2d_t mvd_temp1, mvd_temp2, mvd = { 0, 0 }; + vector2d_t mvd = { 0, 0 }; int8_t merged = 0; int8_t cur_mv_cand = 0; @@ -952,20 +949,23 @@ uint32_t kvz_calc_mvd_cost_cabac(const encoder_state_t * state, cabac = &state_cabac_copy; if (!merged) { - mvd_temp1.x = x - mv_cand[0][0]; - mvd_temp1.y = y - mv_cand[0][1]; - cand1_cost = kvz_get_mvd_coding_cost_cabac(state, &mvd_temp1, cabac); - - mvd_temp2.x = x - mv_cand[1][0]; - mvd_temp2.y = y - mv_cand[1][1]; - cand2_cost = kvz_get_mvd_coding_cost_cabac(state, &mvd_temp2, cabac); + vector2d_t mvd1 = { + x - mv_cand[0][0], + y - mv_cand[0][1], + }; + vector2d_t mvd2 = { + x - mv_cand[1][0], + y - mv_cand[1][1], + }; + uint32_t cand1_cost = kvz_get_mvd_coding_cost_cabac(state, cabac, mvd1.x, mvd1.y); + uint32_t cand2_cost = kvz_get_mvd_coding_cost_cabac(state, cabac, mvd2.x, mvd2.y); // Select candidate 1 if it has lower cost if (cand2_cost < cand1_cost) { cur_mv_cand = 1; - mvd = mvd_temp2; + mvd = mvd2; } else { - mvd = mvd_temp1; + mvd = mvd1; } } diff --git a/src/rdo.h b/src/rdo.h index 8a8e0022..72450fb7 100644 --- a/src/rdo.h +++ b/src/rdo.h @@ -58,8 +58,9 @@ uint32_t kvz_get_coded_level(encoder_state_t * state, double* coded_cost, double kvz_mvd_cost_func kvz_calc_mvd_cost_cabac; uint32_t kvz_get_mvd_coding_cost_cabac(const encoder_state_t *state, - vector2d_t *mvd, - const cabac_data_t* cabac); + const cabac_data_t* cabac, + int32_t mvd_hor, + int32_t mvd_ver); // Number of fixed point fractional bits used in the fractional bit table. #define CTX_FRAC_BITS 15 diff --git a/src/search_inter.c b/src/search_inter.c index f0aa11ef..306f89e1 100644 --- a/src/search_inter.c +++ b/src/search_inter.c @@ -307,11 +307,12 @@ static void select_starting_point(inter_search_info_t *info, vector2d_t extra_mv static uint32_t get_mvd_coding_cost(const encoder_state_t *state, - vector2d_t *mvd, - const cabac_data_t* cabac) + const cabac_data_t* cabac, + const int32_t mvd_hor, + const int32_t mvd_ver) { unsigned bitcost = 0; - const vector2d_t abs_mvd = { abs(mvd->x), abs(mvd->y) }; + const vector2d_t abs_mvd = { abs(mvd_hor), abs(mvd_ver) }; bitcost += CTX_ENTROPY_BITS(&cabac->ctx.cu_mvd_model[0], abs_mvd.x > 0); if (abs_mvd.x > 0) { @@ -336,6 +337,53 @@ static uint32_t get_mvd_coding_cost(const encoder_state_t *state, } +static int select_mv_cand(const encoder_state_t *state, + int16_t mv_cand[2][2], + int32_t mv_x, + int32_t mv_y, + uint32_t *cost_out) +{ + const bool same_cand = + (mv_cand[0][0] == mv_cand[1][0] && mv_cand[0][1] == mv_cand[1][1]); + + if (same_cand && !cost_out) { + // Pick the first one if both candidates are the same. + return 0; + } + + uint32_t (*mvd_coding_cost)(const encoder_state_t * const state, + const cabac_data_t*, + int32_t, int32_t); + if (state->encoder_control->cfg.mv_rdo) { + mvd_coding_cost = kvz_get_mvd_coding_cost_cabac; + } else { + mvd_coding_cost = get_mvd_coding_cost; + } + + uint32_t cand1_cost = mvd_coding_cost( + state, &state->cabac, + mv_x - mv_cand[0][0], + mv_y - mv_cand[0][1]); + + uint32_t cand2_cost; + if (same_cand) { + cand2_cost = cand1_cost; + } else { + cand2_cost = mvd_coding_cost( + state, &state->cabac, + mv_x - mv_cand[1][0], + mv_y - mv_cand[1][1]); + } + + if (cost_out) { + *cost_out = MIN(cand1_cost, cand2_cost); + } + + // Pick the second candidate if it has lower cost. + return cand2_cost < cand1_cost ? 1 : 0; +} + + static uint32_t calc_mvd_cost(const encoder_state_t *state, int x, int y, @@ -348,10 +396,7 @@ static uint32_t calc_mvd_cost(const encoder_state_t *state, { uint32_t temp_bitcost = 0; uint32_t merge_idx; - int cand1_cost,cand2_cost; - vector2d_t mvd_temp1, mvd_temp2; int8_t merged = 0; - int8_t cur_mv_cand = 0; x *= 1 << mv_shift; y *= 1 << mv_shift; @@ -371,20 +416,10 @@ static uint32_t calc_mvd_cost(const encoder_state_t *state, } // Check mvd cost only if mv is not merged - if(!merged) { - mvd_temp1.x = x - mv_cand[0][0]; - mvd_temp1.y = y - mv_cand[0][1]; - cand1_cost = get_mvd_coding_cost(state, &mvd_temp1, &state->cabac); - - mvd_temp2.x = x - mv_cand[1][0]; - mvd_temp2.y = y - mv_cand[1][1]; - cand2_cost = get_mvd_coding_cost(state, &mvd_temp2, &state->cabac); - - // Select candidate 1 if it has lower cost - if (cand2_cost < cand1_cost) { - cur_mv_cand = 1; - } - temp_bitcost += cur_mv_cand ? cand2_cost : cand1_cost; + if (!merged) { + uint32_t mvd_cost = 0; + select_mv_cand(state, mv_cand, x, y, &mvd_cost); + temp_bitcost += mvd_cost; } *bitcost = temp_bitcost; return temp_bitcost*(int32_t)(state->lambda_sqrt + 0.5); @@ -1315,30 +1350,9 @@ static void search_pu_inter_ref(inter_search_info_t *info, // Only check when candidates are different int cu_mv_cand = 0; - if (!merged && ( - info->mv_cand[0][0] != info->mv_cand[1][0] || - info->mv_cand[0][1] != info->mv_cand[1][1])) - { - uint32_t (*mvd_coding_cost)(const encoder_state_t * const state, - vector2d_t *, - const cabac_data_t*) = - cfg->mv_rdo ? kvz_get_mvd_coding_cost_cabac : get_mvd_coding_cost; - - vector2d_t mvd_temp1, mvd_temp2; - int cand1_cost,cand2_cost; - - mvd_temp1.x = mv.x - info->mv_cand[0][0]; - mvd_temp1.y = mv.y - info->mv_cand[0][1]; - cand1_cost = mvd_coding_cost(info->state, &mvd_temp1, &info->state->cabac); - - mvd_temp2.x = mv.x - info->mv_cand[1][0]; - mvd_temp2.y = mv.y - info->mv_cand[1][1]; - cand2_cost = mvd_coding_cost(info->state, &mvd_temp2, &info->state->cabac); - - // Select candidate 1 if it has lower cost - if (cand2_cost < cand1_cost) { - cu_mv_cand = 1; - } + if (!merged) { + cu_mv_cand = + select_mv_cand(info->state, info->mv_cand, mv.x, mv.y, NULL); } if (info->best_cost < *inter_cost) { @@ -1465,7 +1479,6 @@ static void search_pu_inter(encoder_state_t * const state, { uint32_t bitcost[2]; uint32_t cost = 0; - int8_t cu_mv_cand = 0; int16_t mv[2][2]; kvz_pixel tmp_block[64 * 64]; kvz_pixel tmp_pic[64 * 64]; @@ -1558,32 +1571,13 @@ static void search_pu_inter(encoder_state_t * const state, // Each motion vector has its own candidate for (int reflist = 0; reflist < 2; reflist++) { - cu_mv_cand = 0; kvz_inter_get_mv_cand(state, x, y, width, height, info.mv_cand, cur_cu, lcu, reflist); - if (info.mv_cand[0][0] != info.mv_cand[1][0] || - info.mv_cand[0][1] != info.mv_cand[1][1]) - { - uint32_t (*mvd_coding_cost)(const encoder_state_t * const state, - vector2d_t *, - const cabac_data_t*) = - cfg->mv_rdo ? kvz_get_mvd_coding_cost_cabac : get_mvd_coding_cost; - - vector2d_t mvd_temp1, mvd_temp2; - int cand1_cost, cand2_cost; - - mvd_temp1.x = cur_cu->inter.mv[reflist][0] - info.mv_cand[0][0]; - mvd_temp1.y = cur_cu->inter.mv[reflist][1] - info.mv_cand[0][1]; - cand1_cost = mvd_coding_cost(state, &mvd_temp1, (cabac_data_t*)&state->cabac); - - mvd_temp2.x = cur_cu->inter.mv[reflist][0] - info.mv_cand[1][0]; - mvd_temp2.y = cur_cu->inter.mv[reflist][1] - info.mv_cand[1][1]; - cand2_cost = mvd_coding_cost(state, &mvd_temp2, (cabac_data_t*)&state->cabac); - - // Select candidate 1 if it has lower cost - if (cand2_cost < cand1_cost) { - cu_mv_cand = 1; - } - } + int cu_mv_cand = select_mv_cand( + state, + info.mv_cand, + cur_cu->inter.mv[reflist][0], + cur_cu->inter.mv[reflist][1], + NULL); CU_SET_MV_CAND(cur_cu, reflist, cu_mv_cand); }