diff --git a/src/search_inter.c b/src/search_inter.c
index 51c482d7..f04de916 100644
--- a/src/search_inter.c
+++ b/src/search_inter.c
@@ -100,17 +100,57 @@ static INLINE bool intmv_within_tile(const encoder_state_t *state, const vector2
 }
 
 
-static uint32_t get_ep_ex_golomb_bitcost(uint32_t symbol, uint32_t count)
+/**
+ * \brief Estimate the number of bins used to code a symbol.
+ *
+ * Symbols below 8190 use a lookup that returns 2 * k bins for the smallest
+ * k in 1..12 with symbol < 2^(k + 1) - 2. Larger symbols are extrapolated
+ * by adding 2 bins for every further 8192 values.
+ *
+ * \param symbol  Value to be coded.
+ * \return        Estimated number of bins.
+ */
+static uint32_t get_ep_ex_golomb_bitcost(uint32_t symbol)
 {
-  int32_t num_bins = 0;
-  while (symbol >= (uint32_t)(1 << count)) {
-    ++num_bins;
-    symbol -= 1 << count;
-    ++count;
-  }
-  num_bins ++;
+  unsigned bins;
 
-  return num_bins;
+  if (symbol < 2) {
+    bins = 2;
+  } else if (symbol < 6) {
+    bins = 4;
+  } else if (symbol < 14) {
+    bins = 6;
+  } else if (symbol < 30) {
+    bins = 8;
+  } else if (symbol < 62) {
+    bins = 10;
+  } else if (symbol < 126) {
+    bins = 12;
+  } else if (symbol < 254) {
+    bins = 14;
+  } else if (symbol < 510) {
+    bins = 16;
+  } else if (symbol < 1022) {
+    bins = 18;
+  } else if (symbol < 2046) {
+    bins = 20;
+  } else if (symbol < 4094) {
+    bins = 22;
+  } else if (symbol < 8190) {
+    bins = 24;
+  } else {
+    // Estimate bigger symbols with the current slope.
+    // (2 bits per 8192)
+    bins = 26 + 2 * ((symbol - 8190) >> 13);
+  }
+
+  // TODO: It might be a good idea to put a small slope on this function to
+  // make sure any search function that follows the gradient heads towards
+  // a smaller MVD, but that would require fractional costs and bits being
+  // used everywhere in inter search.
+  // return bins + 0.001 * symbol;
+
+  return bins;
 }
 
 
@@ -177,34 +217,40 @@ static unsigned select_starting_point(int16_t num_cand, inter_merge_cand_t *merg
 
+/**
+ * \brief Estimate the cost of coding a motion vector difference.
+ *
+ * For each component, adds CTX_ENTROPY_FBITS of cu_mvd_model[0] for the
+ * "greater than 0" flag and of cu_mvd_model[1] for the "greater than 1"
+ * flag, the bin estimate for abs - 2 when abs > 1, and one bin for the sign.
+ *
+ * \param state  Not used by this function.
+ * \param mvd    Motion vector difference to be costed.
+ * \param cabac  CABAC data providing the context models; must not be NULL.
+ * \return       Total cost rounded to the nearest integer.
+ */
 static uint32_t get_mvd_coding_cost(encoder_state_t * const state, vector2d_t *mvd, const cabac_data_t* cabac)
 {
-  uint32_t bitcost = 0;
-  const int32_t mvd_hor = mvd->x;
-  const int32_t mvd_ver = mvd->y;
-  const int8_t hor_abs_gr0 = mvd_hor != 0;
-  const int8_t ver_abs_gr0 = mvd_ver != 0;
-  const uint32_t mvd_hor_abs = abs(mvd_hor);
-  const uint32_t mvd_ver_abs = abs(mvd_ver);
+  double bitcost = 0;
+  const vector2d_t abs_mvd = { abs(mvd->x), abs(mvd->y) };
 
-  // Greater than 0 for x/y
-  bitcost += 2;
-
-  if (hor_abs_gr0) {
-    if (mvd_hor_abs > 1) {
-      bitcost += get_ep_ex_golomb_bitcost(mvd_hor_abs-2, 1) - 2; // TODO: tune the costs
+  bitcost += CTX_ENTROPY_FBITS(&cabac->ctx.cu_mvd_model[0], abs_mvd.x > 0);
+  if (abs_mvd.x > 0) {
+    bitcost += CTX_ENTROPY_FBITS(&cabac->ctx.cu_mvd_model[1], abs_mvd.x > 1);
+    if (abs_mvd.x > 1) {
+      bitcost += get_ep_ex_golomb_bitcost(abs_mvd.x - 2);
     }
-    // Greater than 1 + sign
-    bitcost += 2;
+    bitcost += 1; // sign
   }
 
-  if (ver_abs_gr0) {
-    if (mvd_ver_abs > 1) {
-      bitcost += get_ep_ex_golomb_bitcost(mvd_ver_abs-2, 1) - 2; // TODO: tune the costs
+  bitcost += CTX_ENTROPY_FBITS(&cabac->ctx.cu_mvd_model[0], abs_mvd.y > 0);
+  if (abs_mvd.y > 0) {
+    bitcost += CTX_ENTROPY_FBITS(&cabac->ctx.cu_mvd_model[1], abs_mvd.y > 1);
+    if (abs_mvd.y > 1) {
+      bitcost += get_ep_ex_golomb_bitcost(abs_mvd.y - 2);
     }
-    // Greater than 1 + sign
-    bitcost += 2;
+    bitcost += 1; // sign
   }
 
-  return bitcost;
+  return bitcost + 0.5; // adding 0.5 rounds to nearest on conversion to uint32_t
 }
 
 
@@ -238,11 +284,11 @@ static int calc_mvd_cost(encoder_state_t * const state, int x, int y, int mv_shi
   if(!merged) {
     mvd_temp1.x = x - mv_cand[0][0];
     mvd_temp1.y = y - mv_cand[0][1];
-    cand1_cost = get_mvd_coding_cost(state, &mvd_temp1, NULL);
+    cand1_cost = get_mvd_coding_cost(state, &mvd_temp1, &state->cabac);
 
     mvd_temp2.x = x - mv_cand[1][0];
     mvd_temp2.y = y - mv_cand[1][1];
-    cand2_cost = get_mvd_coding_cost(state, &mvd_temp2, NULL);
+    cand2_cost = get_mvd_coding_cost(state, &mvd_temp2, &state->cabac);
 
     // Select candidate 1 if it has lower cost
     if (cand2_cost < cand1_cost) {