Added coding cost calculations to MV search

This commit is contained in:
Marko Viitanen 2014-03-07 15:05:39 +02:00
parent 9dde96f25e
commit c7e4861dbf
2 changed files with 167 additions and 75 deletions

View file

@ -52,9 +52,10 @@ typedef struct {
*/
typedef struct
{
uint32_t cost;
uint32_t bitcost;
int8_t mode;
int8_t mode_chroma;
uint32_t cost;
} cu_info_intra;
/**
@ -62,14 +63,14 @@ typedef struct
*/
typedef struct
{
int8_t mode;
uint32_t cost;
uint32_t bitcost;
int16_t mv[2];
int16_t mvd[2];
uint8_t mv_cand; // \brief selected MV candidate
uint8_t mv_ref; // \brief Index of the encoder_control.ref array.
uint8_t mv_dir; // \brief Probably describes if mv_ref is forward, backward or both. Might not be needed?
int8_t mode;
} cu_info_inter;
/**

View file

@ -71,26 +71,94 @@ const vector2d small_hexbs[5] = {
};
static int calc_mvd_cost(int x, int y, const vector2d *pred)
/**
 * \brief Count the bins needed to code a symbol with k-th order Exp-Golomb.
 * \param symbol  Value to be coded.
 * \param count   Starting order (k) of the code.
 * \return Number of bins: one per prefix step, one prefix terminator and
 *         as many suffix bins as the order that was reached.
 */
static uint32_t get_ep_ex_golomb_bitcost(uint32_t symbol, uint32_t count)
{
  uint32_t num_bins = 0;

  // The shift is done on an unsigned value and only while count < 32:
  // shifting a signed 1 into the sign bit, or shifting by the full width of
  // the type, is undefined behaviour. Once count reaches 32 the threshold
  // would be 2^32, which no 32-bit symbol can reach, so stopping is exact.
  while (count < 32 && symbol >= ((uint32_t)1 << count)) {
    ++num_bins;
    symbol -= (uint32_t)1 << count;
    ++count;
  }

  // Prefix terminator (1 bin) plus the suffix (count bins).
  num_bins += count + 1;

  return num_bins;
}
/**
 * \brief Estimate the number of bins needed to code a motion vector difference.
 * \param mvd  Motion vector difference; its x and y components are read.
 * \return Bin count for both components together.
 */
static uint32_t get_mvd_coding_cost(vector2d *mvd)
{
  const int32_t component[2] = { mvd->x, mvd->y };
  // The "greater than 0" flag is always present for both x and y.
  uint32_t bins = 2;
  int i;

  for (i = 0; i < 2; ++i) {
    const uint32_t magnitude = abs(component[i]);

    // A zero component costs nothing beyond its "greater than 0" flag.
    if (magnitude == 0) {
      continue;
    }

    // Two more bins for every non-zero component (presumably the
    // "greater than 1" flag and the sign -- confirm against the bitstream
    // writer).
    bins += 2;

    // The remainder above 1 is counted as a first order Exp-Golomb code.
    if (magnitude > 1) {
      bins += get_ep_ex_golomb_bitcost(magnitude - 2, 1);
    }
  }

  return bins;
}
static int calc_mvd_cost(int x, int y, const vector2d *pred,
int16_t mv_cand[2][2], int16_t merge_cand[MRG_MAX_NUM_CANDS][3],
int16_t num_cand,int32_t ref_idx)
{
int cost = 0;
// Get the absolute difference vector and count the bits.
x = abs(abs(x) - abs(pred->x));
y = abs(abs(y) - abs(pred->y));
while (x >>= 1) {
++cost;
}
while (y >>= 1) {
++cost;
uint32_t temp_bitcost = 0;
uint32_t merge_idx;
int cand1_cost,cand2_cost;
vector2d mvd_temp1, mvd_temp2;
int8_t merged = 0;
int8_t cur_mv_cand = 0;
x <<= 2;
y <<= 2;
// Check every candidate to find a match
for(merge_idx = 0; merge_idx < num_cand; merge_idx++) {
if (merge_cand[merge_idx][0] == x &&
merge_cand[merge_idx][1] == y &&
merge_cand[merge_idx][2] == ref_idx) {
temp_bitcost += merge_idx < MRG_MAX_NUM_CANDS ? merge_idx+1:merge_idx;
merged = 1;
break;
}
}
// I don't know what a good cost function for this is. It probably doesn't
// have to approximate the actual cost of encoding the vector, but it's a
// place to start.
// Check mvd cost only if mv is not merged
if(!merged) {
mvd_temp1.x = x - mv_cand[0][0];
mvd_temp1.y = y - mv_cand[0][1];
cand1_cost = get_mvd_coding_cost(&mvd_temp1);
// Add two for quarter pixel resolution and multiply by two for Exp-Golomb.
return (cost ? (cost + 2) << 1 : 0);
mvd_temp2.x = x - mv_cand[1][0];
mvd_temp2.y = y - mv_cand[1][1];
cand2_cost = get_mvd_coding_cost(&mvd_temp2);
// Select candidate 1 if it has lower cost
if (cand2_cost < cand1_cost) {
cur_mv_cand = 1;
}
temp_bitcost += cur_mv_cand ? cand2_cost : cand1_cost;
}
return temp_bitcost*(int32_t)(g_cur_lambda_cost+0.5);
}
@ -116,7 +184,9 @@ static int calc_mvd_cost(int x, int y, const vector2d *pred)
*/
static unsigned hexagon_search(unsigned depth,
const picture *pic, const picture *ref,
const vector2d *orig, vector2d *mv_in_out)
const vector2d *orig, vector2d *mv_in_out,
int16_t mv_cand[2][2], int16_t merge_cand[MRG_MAX_NUM_CANDS][3],
int16_t num_cand, int32_t ref_idx)
{
vector2d mv = { mv_in_out->x >> 2, mv_in_out->y >> 2 };
int block_width = CU_WIDTH_FROM_DEPTH(depth);
@ -124,13 +194,14 @@ static unsigned hexagon_search(unsigned depth,
unsigned i;
unsigned best_index = 0; // Index of large_hexbs or finally small_hexbs.
// Search the initial 7 points of the hexagon.
for (i = 0; i < 7; ++i) {
const vector2d *pattern = &large_hexbs[i];
unsigned cost = calc_sad(pic, ref, orig->x, orig->y,
orig->x + mv.x + pattern->x, orig->y + mv.y + pattern->y,
block_width, block_width);
cost += calc_mvd_cost(mv.x + pattern->x, mv.y + pattern->y, mv_in_out);
cost += calc_mvd_cost(mv.x + pattern->x, mv.y + pattern->y, mv_in_out,mv_cand,merge_cand,num_cand,ref_idx);
if (cost < best_cost) {
best_cost = cost;
@ -143,7 +214,7 @@ static unsigned hexagon_search(unsigned depth,
unsigned cost = calc_sad(pic, ref, orig->x, orig->y,
orig->x, orig->y,
block_width, block_width);
cost += calc_mvd_cost(0, 0, mv_in_out);
cost += calc_mvd_cost(0, 0, mv_in_out,mv_cand,merge_cand,num_cand,ref_idx);
// If the 0,0 is better, redo the hexagon around that point.
if (cost < best_cost) {
@ -158,7 +229,7 @@ static unsigned hexagon_search(unsigned depth,
orig->x + pattern->x,
orig->y + pattern->y,
block_width, block_width);
cost += calc_mvd_cost(pattern->x, pattern->y, mv_in_out);
cost += calc_mvd_cost(pattern->x, pattern->y, mv_in_out,mv_cand,merge_cand,num_cand,ref_idx);
if (cost < best_cost) {
best_cost = cost;
@ -192,7 +263,7 @@ static unsigned hexagon_search(unsigned depth,
orig->x + mv.x + offset->x,
orig->y + mv.y + offset->y,
block_width, block_width);
cost += calc_mvd_cost(mv.x + offset->x, mv.y + offset->y, mv_in_out);
cost += calc_mvd_cost(mv.x + offset->x, mv.y + offset->y, mv_in_out,mv_cand,merge_cand,num_cand,ref_idx);
if (cost < best_cost) {
best_cost = cost;
@ -214,7 +285,7 @@ static unsigned hexagon_search(unsigned depth,
orig->x + mv.x + offset->x,
orig->y + mv.y + offset->y,
block_width, block_width);
cost += calc_mvd_cost(mv.x + offset->x, mv.y + offset->y, mv_in_out);
cost += calc_mvd_cost(mv.x + offset->x, mv.y + offset->y, mv_in_out,mv_cand,merge_cand,num_cand,ref_idx);
if (cost > 0 && cost < best_cost) {
best_cost = cost;
@ -237,7 +308,9 @@ static unsigned hexagon_search(unsigned depth,
#if SEARCH_MV_FULL_RADIUS
static unsigned search_mv_full(unsigned depth,
const picture *pic, const picture *ref,
const vector2d *orig, vector2d *mv_in_out)
const vector2d *orig, vector2d *mv_in_out,
int16_t mv_cand[2][2], int16_t merge_cand[MRG_MAX_NUM_CANDS][3],
int16_t num_cand, int32_t ref_idx)
{
vector2d mv = { mv_in_out->x >> 2, mv_in_out->y >> 2 };
int block_width = CU_WIDTH_FROM_DEPTH(depth);
@ -264,7 +337,7 @@ static unsigned search_mv_full(unsigned depth,
orig->x + x,
orig->y + y,
block_width, block_width);
cost += calc_mvd_cost(x, y, mv_in_out);
cost += calc_mvd_cost(x, y, mv_in_out,mv_cand,merge_cand,num_cand,ref_idx);
if (cost < best_cost) {
best_cost = cost;
mv.x = x;
@ -280,7 +353,6 @@ static unsigned search_mv_full(unsigned depth,
}
#endif
/**
* Update lcu to have best modes at this depth.
* \return Cost of best mode.
@ -296,18 +368,34 @@ static int search_cu_inter(encoder_control *encoder, int x, int y, int depth, lc
cu_info *cur_cu = &lcu->cu[cu_pos];
int16_t mv_cand[2][2];
// Search for merge mode candidate
int16_t merge_cand[MRG_MAX_NUM_CANDS][3];
// Get list of candidates
int16_t num_cand = inter_get_merge_cand(x, y, depth, merge_cand, cur_cu, lcu);
// Get MV candidates
inter_get_mv_cand(encoder, x, y, depth, mv_cand, cur_cu, lcu);
// Select better candidate
cur_cu->inter.mv_cand = 0; // Default to candidate 0
cur_cu->inter.cost = UINT_MAX;
for (ref_idx = 0; ref_idx < encoder->ref->used_size; ref_idx++) {
picture *ref_pic = encoder->ref->pics[ref_idx];
unsigned width_in_scu = NO_SCU_IN_LCU(ref_pic->width_in_lcu);
cu_info *ref_cu = &ref_pic->cu_array[MAX_DEPTH][y_cu * width_in_scu + x_cu];
uint32_t temp_cost = (int)(g_lambda_cost[encoder->QP] * ref_idx);
vector2d orig, mv;
uint32_t temp_bitcost = ref_idx;
uint32_t temp_cost = 0;
vector2d orig, mv, mvd;
int32_t merged = 0;
orig.x = x_cu * CU_MIN_SIZE_PIXELS;
orig.y = y_cu * CU_MIN_SIZE_PIXELS;
mv.x = 0;
mv.y = 0;
mvd.x = 0;
mvd.y = 0;
if (ref_cu->type == CU_INTER) {
mv.x = ref_cu->inter.mv[0];
mv.y = ref_cu->inter.mv[1];
@ -316,14 +404,57 @@ static int search_cu_inter(encoder_control *encoder, int x, int y, int depth, lc
#if SEARCH_MV_FULL_RADIUS
temp_cost += search_mv_full(depth, cur_pic, ref_pic, &orig, &mv);
#else
temp_cost += hexagon_search(depth, cur_pic, ref_pic, &orig, &mv);
temp_cost += hexagon_search(depth, cur_pic, ref_pic, &orig, &mv, mv_cand, merge_cand, num_cand, ref_idx);
#endif
merged = 0;
// Check every candidate to find a match
for(cur_cu->merge_idx = 0; cur_cu->merge_idx < num_cand; cur_cu->merge_idx++) {
if (merge_cand[cur_cu->merge_idx][0] == mv.x &&
merge_cand[cur_cu->merge_idx][1] == mv.y &&
merge_cand[cur_cu->merge_idx][2] == ref_idx) {
merged = 1;
//temp_bitcost += cur_cu->merge_idx < MRG_MAX_NUM_CANDS ? cur_cu->merge_idx+1:cur_cu->merge_idx;
break;
}
}
// Only check when candidates are different
if (!merged && (mv_cand[0][0] != mv_cand[1][0] || mv_cand[0][1] != mv_cand[1][1])) {
vector2d mvd_temp1, mvd_temp2;
int cand1_cost,cand2_cost;
mvd_temp1.x = mv.x - mv_cand[0][0];
mvd_temp1.y = mv.y - mv_cand[0][1];
cand1_cost = get_mvd_coding_cost(&mvd_temp1);
mvd_temp2.x = mv.x - mv_cand[1][0];
mvd_temp2.y = mv.y - mv_cand[1][1];
cand2_cost = get_mvd_coding_cost(&mvd_temp2);
// Select candidate 1 if it has lower cost
if (cand2_cost < cand1_cost) {
cur_cu->inter.mv_cand = 1;
}
}
mvd.x = mv.x - mv_cand[cur_cu->inter.mv_cand][0];
mvd.y = mv.y - mv_cand[cur_cu->inter.mv_cand][1];
//temp_bitcost += merged ? 0 : get_mvd_coding_cost(&mvd);
//temp_cost += temp_bitcost*g_cur_lambda_cost;
if(temp_cost < cur_cu->inter.cost) {
cur_cu->inter.mv_ref = ref_idx;
cur_cu->inter.mv_dir = 1;
cur_cu->inter.mv[0] = (int16_t)mv.x;
cur_cu->inter.mv[1] = (int16_t)mv.y;
cur_cu->inter.cost = temp_cost;
cur_cu->merged = merged;
cur_cu->inter.mv_ref = ref_idx;
cur_cu->inter.mv_dir = 1;
cur_cu->inter.mv[0] = (int16_t)mv.x;
cur_cu->inter.mv[1] = (int16_t)mv.y;
cur_cu->inter.mvd[0] = (int16_t)mvd.x;
cur_cu->inter.mvd[1] = (int16_t)mvd.y;
cur_cu->inter.cost = temp_cost;
cur_cu->inter.bitcost = temp_bitcost;
}
}
@ -680,46 +811,6 @@ static int search_cu(encoder_control *encoder, int x, int y, int depth, lcu_t wo
lcu_set_intra_mode(&work_tree[depth], x, y, depth, cur_cu->intra[0].mode, cur_cu->part_size);
intra_recon_lcu(encoder, x, y, depth,&work_tree[depth],encoder->in.cur_pic->width,encoder->in.cur_pic->height);
} else if (cur_cu->type == CU_INTER) {
int16_t mv_cand[2][2];
// Search for merge mode candidate
int16_t merge_cand[MRG_MAX_NUM_CANDS][3];
// Get list of candidates
int16_t num_cand = inter_get_merge_cand(x, y, depth, merge_cand, cur_cu, &work_tree[depth]);
// Check every candidate to find a match
for(cur_cu->merge_idx = 0; cur_cu->merge_idx < num_cand; cur_cu->merge_idx++) {
if(merge_cand[cur_cu->merge_idx][0] == cur_cu->inter.mv[0] &&
merge_cand[cur_cu->merge_idx][1] == cur_cu->inter.mv[1] &&
merge_cand[cur_cu->merge_idx][2] == cur_cu->inter.mv_ref) {
cur_cu->merged = 1;
break;
}
}
// Get MV candidates
inter_get_mv_cand(encoder, x, y, depth, mv_cand, cur_cu, &work_tree[depth]);
// Select better candidate
cur_cu->inter.mv_cand = 0; // Default to candidate 0
// Only check when candidates are different
if (mv_cand[0][0] != mv_cand[1][0] || mv_cand[0][1] != mv_cand[1][1]) {
// TODO: calculate bit costs
int cand_1_diff = abs(cur_cu->inter.mv[0] - mv_cand[0][0]) + abs(
cur_cu->inter.mv[1] - mv_cand[0][1]);
int cand_2_diff = abs(cur_cu->inter.mv[0] - mv_cand[1][0]) + abs(
cur_cu->inter.mv[1] - mv_cand[1][1]);
// Select candidate 1 if it's closer
if (cand_2_diff < cand_1_diff) {
cur_cu->inter.mv_cand = 1;
}
}
cur_cu->inter.mvd[0] = cur_cu->inter.mv[0] - mv_cand[cur_cu->inter.mv_cand][0];
cur_cu->inter.mvd[1] = cur_cu->inter.mv[1] - mv_cand[cur_cu->inter.mv_cand][1];
cur_cu->coded = 1;
inter_recon_lcu(encoder->ref->pics[cur_cu->inter.mv_ref], x, y, LCU_WIDTH>>depth, cur_cu->inter.mv, &work_tree[depth]);
encode_transform_tree(encoder, x, y, depth, &work_tree[depth]);
@ -731,7 +822,7 @@ static int search_cu(encoder_control *encoder, int x, int y, int depth, lcu_t wo
lcu_set_coeff(&work_tree[depth], x, y, depth, cur_cu);
}
}
//cost = lcu_get_final_cost(encoder, x, y, depth, cur_cu);
// Recursively split all the way to max search depth.
if (depth < MAX_INTRA_SEARCH_DEPTH || depth < MAX_INTER_SEARCH_DEPTH) {