From 82cfab58f82c173d56a27dc49eeb82bb76f549be Mon Sep 17 00:00:00 2001 From: Ari Koivula Date: Mon, 29 Aug 2016 23:51:20 +0300 Subject: [PATCH] Improve fast mvd coding cost estimation A lot of time is being taken up by this function on ultrafast, and it doesn't do a very good job. This change aims to both simplify the logic and make the estimate better. The logic is simplified by using a look up for the step mvd bit cost step function instead of mimicking the binarization process. The estimation is made better by checking fractional cabac bit costs. The new function returns the same results as kvz_get_mvd_coding_cost_cabac, but is also faster than the old function. --- src/search_inter.c | 88 +++++++++++++++++++++++++++++----------------- 1 file changed, 56 insertions(+), 32 deletions(-) diff --git a/src/search_inter.c b/src/search_inter.c index 51c482d7..f04de916 100644 --- a/src/search_inter.c +++ b/src/search_inter.c @@ -100,17 +100,47 @@ static INLINE bool intmv_within_tile(const encoder_state_t *state, const vector2 } -static uint32_t get_ep_ex_golomb_bitcost(uint32_t symbol, uint32_t count) +static uint32_t get_ep_ex_golomb_bitcost(uint32_t symbol) { - int32_t num_bins = 0; - while (symbol >= (uint32_t)(1 << count)) { - ++num_bins; - symbol -= 1 << count; - ++count; - } - num_bins ++; + unsigned bins; - return num_bins; + if (symbol < 2) { + bins = 2; + } else if (symbol < 6) { + bins = 4; + } else if (symbol < 14) { + bins = 6; + } else if (symbol < 30) { + bins = 8; + } else if (symbol < 62) { + bins = 10; + } else if (symbol < 126) { + bins = 12; + } else if (symbol < 254) { + bins = 14; + } else if (symbol < 510) { + bins = 16; + } else if (symbol < 1022) { + bins = 18; + } else if (symbol < 2046) { + bins = 20; + } else if (symbol < 4094) { + bins = 22; + } else if (symbol < 8190) { + bins = 24; + } else { + // Estimate bigger symbols with the current slope. + // (2 bits per 8192) + bins = 26 + 2 * (symbol - 8190) >> 13; + } + + // TODO: It might be a good idea to put a small slope on this function to + // make sure any search function that follows the gradient heads towards + // a smaller MVD, but that would require fractinal costs and bits being + // used everywhere in inter search. + // return num_bins + 0.001 * symbol; + + return bins; } @@ -177,34 +207,28 @@ static unsigned select_starting_point(int16_t num_cand, inter_merge_cand_t *merg static uint32_t get_mvd_coding_cost(encoder_state_t * const state, vector2d_t *mvd, const cabac_data_t* cabac) { - uint32_t bitcost = 0; - const int32_t mvd_hor = mvd->x; - const int32_t mvd_ver = mvd->y; - const int8_t hor_abs_gr0 = mvd_hor != 0; - const int8_t ver_abs_gr0 = mvd_ver != 0; - const uint32_t mvd_hor_abs = abs(mvd_hor); - const uint32_t mvd_ver_abs = abs(mvd_ver); + double bitcost = 0; + const vector2d_t abs_mvd = { abs(mvd->x), abs(mvd->y) }; - // Greater than 0 for x/y - bitcost += 2; - - if (hor_abs_gr0) { - if (mvd_hor_abs > 1) { - bitcost += get_ep_ex_golomb_bitcost(mvd_hor_abs-2, 1) - 2; // TODO: tune the costs + bitcost += CTX_ENTROPY_FBITS(&cabac->ctx.cu_mvd_model[0], abs_mvd.x > 0); + if (abs_mvd.x > 0) { + bitcost += CTX_ENTROPY_FBITS(&cabac->ctx.cu_mvd_model[1], abs_mvd.x > 1); + if (abs_mvd.x > 1) { + bitcost += get_ep_ex_golomb_bitcost(abs_mvd.x - 2); } - // Greater than 1 + sign - bitcost += 2; + bitcost += 1; // sign } - if (ver_abs_gr0) { - if (mvd_ver_abs > 1) { - bitcost += get_ep_ex_golomb_bitcost(mvd_ver_abs-2, 1) - 2; // TODO: tune the costs + bitcost += CTX_ENTROPY_FBITS(&cabac->ctx.cu_mvd_model[0], abs_mvd.y > 0); + if (abs_mvd.y > 0) { + bitcost += CTX_ENTROPY_FBITS(&cabac->ctx.cu_mvd_model[1], abs_mvd.y > 1); + if (abs_mvd.y > 1) { + bitcost += get_ep_ex_golomb_bitcost(abs_mvd.y - 2); } - // Greater than 1 + sign - bitcost += 2; + bitcost += 1; // sign } - return bitcost; + return bitcost + 0.5; } @@ -238,11 +262,11 @@ static int calc_mvd_cost(encoder_state_t * const state, int x, int y, int mv_shi if(!merged) { mvd_temp1.x = x - mv_cand[0][0]; mvd_temp1.y = y - mv_cand[0][1]; - cand1_cost = get_mvd_coding_cost(state, &mvd_temp1, NULL); + cand1_cost = get_mvd_coding_cost(state, &mvd_temp1, &state->cabac); mvd_temp2.x = x - mv_cand[1][0]; mvd_temp2.y = y - mv_cand[1][1]; - cand2_cost = get_mvd_coding_cost(state, &mvd_temp2, NULL); + cand2_cost = get_mvd_coding_cost(state, &mvd_temp2, &state->cabac); // Select candidate 1 if it has lower cost if (cand2_cost < cand1_cost) {