Improve fast mvd coding cost estimation

A lot of time is being taken up by this function on ultrafast, and it
doesn't do a very good job. This change aims to both simplify the
logic and make the estimate better.

The logic is simplified by using a look up for the step mvd bit cost
step function instead of mimicking the binarization process. The
estimation is made better by checking fractional cabac bit costs.

The new function returns the same results as
kvz_get_mvd_coding_cost_cabac, but is also faster than the old
function.
This commit is contained in:
Ari Koivula 2016-08-29 23:51:20 +03:00
parent d31be8eb27
commit 82cfab58f8

View file

@ -100,17 +100,47 @@ static INLINE bool intmv_within_tile(const encoder_state_t *state, const vector2
} }
static uint32_t get_ep_ex_golomb_bitcost(uint32_t symbol, uint32_t count) static uint32_t get_ep_ex_golomb_bitcost(uint32_t symbol)
{ {
int32_t num_bins = 0; unsigned bins;
while (symbol >= (uint32_t)(1 << count)) {
++num_bins;
symbol -= 1 << count;
++count;
}
num_bins ++;
return num_bins; if (symbol < 2) {
bins = 2;
} else if (symbol < 6) {
bins = 4;
} else if (symbol < 14) {
bins = 6;
} else if (symbol < 30) {
bins = 8;
} else if (symbol < 62) {
bins = 10;
} else if (symbol < 126) {
bins = 12;
} else if (symbol < 254) {
bins = 14;
} else if (symbol < 510) {
bins = 16;
} else if (symbol < 1022) {
bins = 18;
} else if (symbol < 2046) {
bins = 20;
} else if (symbol < 4094) {
bins = 22;
} else if (symbol < 8190) {
bins = 24;
} else {
// Estimate bigger symbols with the current slope.
// (2 bits per 8192)
bins = 26 + 2 * (symbol - 8190) >> 13;
}
// TODO: It might be a good idea to put a small slope on this function to
// make sure any search function that follows the gradient heads towards
// a smaller MVD, but that would require fractinal costs and bits being
// used everywhere in inter search.
// return num_bins + 0.001 * symbol;
return bins;
} }
@ -177,34 +207,28 @@ static unsigned select_starting_point(int16_t num_cand, inter_merge_cand_t *merg
static uint32_t get_mvd_coding_cost(encoder_state_t * const state, vector2d_t *mvd, const cabac_data_t* cabac) static uint32_t get_mvd_coding_cost(encoder_state_t * const state, vector2d_t *mvd, const cabac_data_t* cabac)
{ {
uint32_t bitcost = 0; double bitcost = 0;
const int32_t mvd_hor = mvd->x; const vector2d_t abs_mvd = { abs(mvd->x), abs(mvd->y) };
const int32_t mvd_ver = mvd->y;
const int8_t hor_abs_gr0 = mvd_hor != 0;
const int8_t ver_abs_gr0 = mvd_ver != 0;
const uint32_t mvd_hor_abs = abs(mvd_hor);
const uint32_t mvd_ver_abs = abs(mvd_ver);
// Greater than 0 for x/y bitcost += CTX_ENTROPY_FBITS(&cabac->ctx.cu_mvd_model[0], abs_mvd.x > 0);
bitcost += 2; if (abs_mvd.x > 0) {
bitcost += CTX_ENTROPY_FBITS(&cabac->ctx.cu_mvd_model[1], abs_mvd.x > 1);
if (hor_abs_gr0) { if (abs_mvd.x > 1) {
if (mvd_hor_abs > 1) { bitcost += get_ep_ex_golomb_bitcost(abs_mvd.x - 2);
bitcost += get_ep_ex_golomb_bitcost(mvd_hor_abs-2, 1) - 2; // TODO: tune the costs
} }
// Greater than 1 + sign bitcost += 1; // sign
bitcost += 2;
} }
if (ver_abs_gr0) { bitcost += CTX_ENTROPY_FBITS(&cabac->ctx.cu_mvd_model[0], abs_mvd.y > 0);
if (mvd_ver_abs > 1) { if (abs_mvd.y > 0) {
bitcost += get_ep_ex_golomb_bitcost(mvd_ver_abs-2, 1) - 2; // TODO: tune the costs bitcost += CTX_ENTROPY_FBITS(&cabac->ctx.cu_mvd_model[1], abs_mvd.y > 1);
if (abs_mvd.y > 1) {
bitcost += get_ep_ex_golomb_bitcost(abs_mvd.y - 2);
} }
// Greater than 1 + sign bitcost += 1; // sign
bitcost += 2;
} }
return bitcost; return bitcost + 0.5;
} }
@ -238,11 +262,11 @@ static int calc_mvd_cost(encoder_state_t * const state, int x, int y, int mv_shi
if(!merged) { if(!merged) {
mvd_temp1.x = x - mv_cand[0][0]; mvd_temp1.x = x - mv_cand[0][0];
mvd_temp1.y = y - mv_cand[0][1]; mvd_temp1.y = y - mv_cand[0][1];
cand1_cost = get_mvd_coding_cost(state, &mvd_temp1, NULL); cand1_cost = get_mvd_coding_cost(state, &mvd_temp1, &state->cabac);
mvd_temp2.x = x - mv_cand[1][0]; mvd_temp2.x = x - mv_cand[1][0];
mvd_temp2.y = y - mv_cand[1][1]; mvd_temp2.y = y - mv_cand[1][1];
cand2_cost = get_mvd_coding_cost(state, &mvd_temp2, NULL); cand2_cost = get_mvd_coding_cost(state, &mvd_temp2, &state->cabac);
// Select candidate 1 if it has lower cost // Select candidate 1 if it has lower cost
if (cand2_cost < cand1_cost) { if (cand2_cost < cand1_cost) {