Bi-pred search now actually does cost calculations

This commit is contained in:
Marko Viitanen 2015-04-21 13:53:17 +03:00
parent e12ba7c80f
commit fb74f86a5b

View file

@ -1100,6 +1100,8 @@ static int search_cu_inter(const encoder_state_t * const state, int x, int y, in
// Search bi-pred positions
if (state->global->slicetype == SLICE_B) {
lcu_t *templcu = MALLOC(lcu_t, 1);
cost_pixel_nxn_func *satd = pixels_get_satd_func(LCU_WIDTH >> depth);
#define NUM_PRIORITY_LIST 12;
static const uint8_t priorityList0[] = { 0, 1, 0, 2, 1, 2, 0, 3, 1, 3, 2, 3 };
static const uint8_t priorityList1[] = { 1, 0, 2, 0, 2, 1, 3, 0, 3, 1, 3, 2 };
@ -1113,53 +1115,82 @@ static int search_cu_inter(const encoder_state_t * const state, int x, int y, in
if ((merge_cand[i].dir & 0x1) && (merge_cand[j].dir & 0x2)) {
if (merge_cand[i].ref[0] != merge_cand[j].ref[1] ||
merge_cand[i].mv[0][0] != merge_cand[j].mv[1][0] ||
merge_cand[i].mv[0][1] != merge_cand[j].mv[1][1]) {
merge_cand[i].mv[0][1] != merge_cand[j].mv[1][1]) {
uint32_t bitcost[2];
uint32_t cost = 0;
int8_t cu_mv_cand = 0;
int16_t mv[2][2];
pixel_t tmp_block[64 * 64];
pixel_t tmp_pic[64 * 64];
// Force L0 and L1 references
if (state->global->refmap[merge_cand[i].ref[0]].list == 2 || state->global->refmap[merge_cand[j].ref[1]].list == 1) continue;
cur_cu->inter.mv_dir = 3;
cur_cu->inter.mv_ref_coded[0] = state->global->refmap[merge_cand[i].ref[0]].idx;
cur_cu->inter.mv_ref_coded[1] = state->global->refmap[merge_cand[j].ref[1]].idx;
cur_cu->merged = 0;
cur_cu->inter.mv_ref[0] = merge_cand[i].ref[0];
cur_cu->inter.mv_ref[1] = merge_cand[j].ref[1];
cur_cu->inter.mv[0][0] = merge_cand[i].mv[0][0] & 0xfff8;
cur_cu->inter.mv[0][1] = merge_cand[i].mv[0][1] & 0xfff8;
cur_cu->inter.mv[1][0] = merge_cand[j].mv[1][0] & 0xfff8;
cur_cu->inter.mv[1][1] = merge_cand[j].mv[1][1] & 0xfff8;
mv[0][0] = merge_cand[i].mv[0][0] & 0xfff8;
mv[0][1] = merge_cand[i].mv[0][1] & 0xfff8;
mv[1][0] = merge_cand[j].mv[1][0] & 0xfff8;
mv[1][1] = merge_cand[j].mv[1][1] & 0xfff8;
memset(templcu->rec.y, 0, 64 * 64);
inter_recon_lcu_bipred(state, state->global->ref->images[merge_cand[i].ref[0]], state->global->ref->images[merge_cand[j].ref[1]], x, y, LCU_WIDTH >> depth, mv, templcu);
for (int reflist = 0; reflist < 2; reflist++) {
cu_mv_cand = 0;
inter_get_mv_cand(state, x, y, depth, mv_cand, cur_cu, lcu, reflist);
if ((mv_cand[0][0] != mv_cand[1][0] || mv_cand[0][1] != mv_cand[1][1])) {
vector2d_t mvd_temp1, mvd_temp2;
int cand1_cost, cand2_cost;
mvd_temp1.x = cur_cu->inter.mv[reflist][0] - mv_cand[0][0];
mvd_temp1.y = cur_cu->inter.mv[reflist][1] - mv_cand[0][1];
cand1_cost = get_mvd_coding_cost(&mvd_temp1);
mvd_temp2.x = cur_cu->inter.mv[reflist][0] - mv_cand[1][0];
mvd_temp2.y = cur_cu->inter.mv[reflist][1] - mv_cand[1][1];
cand2_cost = get_mvd_coding_cost(&mvd_temp2);
// Select candidate 1 if it has lower cost
if (cand2_cost < cand1_cost) {
//cu_mv_cand = 1;
cu_mv_cand = 0;
}
for (int ypos = 0; ypos < LCU_WIDTH >> depth; ++ypos) {
int dst_y = ypos*(LCU_WIDTH >> depth);
for (int xpos = 0; xpos < (LCU_WIDTH >> depth); ++xpos) {
tmp_block[dst_y + xpos] = templcu->rec.y[((y + ypos)&(LCU_WIDTH - 1))*LCU_WIDTH + ((x + xpos)&(LCU_WIDTH - 1))];
tmp_pic[dst_y + xpos] = frame->source->y[x + xpos + (y + ypos)*frame->source->width];
}
cur_cu->inter.mvd[reflist][0] = cur_cu->inter.mv[reflist][0] - mv_cand[cu_mv_cand][0];
cur_cu->inter.mvd[reflist][1] = cur_cu->inter.mv[reflist][1] - mv_cand[cu_mv_cand][1];
}
cur_cu->inter.cost = 0;
cur_cu->inter.bitcost = 10 + cur_cu->inter.mv_dir - 1 + cur_cu->inter.mv_ref_coded[0] + cur_cu->inter.mv_ref_coded[1];
cur_cu->inter.mv_cand = cu_mv_cand;
break;
cost = satd(tmp_pic, tmp_block);
cost += calc_mvd_cost(state, merge_cand[i].mv[0][0] & 0xfff8, merge_cand[i].mv[0][1] & 0xfff8, 0, mv_cand, merge_cand, 0, ref_idx, &bitcost[0]);
cost += calc_mvd_cost(state, merge_cand[i].mv[1][0] & 0xfff8, merge_cand[i].mv[1][1] & 0xfff8, 0, mv_cand, merge_cand, 0, ref_idx, &bitcost[1]);
if (cost < cur_cu->inter.cost) {
cur_cu->inter.mv_dir = 3;
cur_cu->inter.mv_ref_coded[0] = state->global->refmap[merge_cand[i].ref[0]].idx;
cur_cu->inter.mv_ref_coded[1] = state->global->refmap[merge_cand[j].ref[1]].idx;
cur_cu->merged = 0;
cur_cu->inter.mv_ref[0] = merge_cand[i].ref[0];
cur_cu->inter.mv_ref[1] = merge_cand[j].ref[1];
cur_cu->inter.mv[0][0] = merge_cand[i].mv[0][0] & 0xfff8;
cur_cu->inter.mv[0][1] = merge_cand[i].mv[0][1] & 0xfff8;
cur_cu->inter.mv[1][0] = merge_cand[j].mv[1][0] & 0xfff8;
cur_cu->inter.mv[1][1] = merge_cand[j].mv[1][1] & 0xfff8;
for (int reflist = 0; reflist < 2; reflist++) {
cu_mv_cand = 0;
inter_get_mv_cand(state, x, y, depth, mv_cand, cur_cu, lcu, reflist);
if ((mv_cand[0][0] != mv_cand[1][0] || mv_cand[0][1] != mv_cand[1][1])) {
vector2d_t mvd_temp1, mvd_temp2;
int cand1_cost, cand2_cost;
mvd_temp1.x = cur_cu->inter.mv[reflist][0] - mv_cand[0][0];
mvd_temp1.y = cur_cu->inter.mv[reflist][1] - mv_cand[0][1];
cand1_cost = get_mvd_coding_cost(&mvd_temp1);
mvd_temp2.x = cur_cu->inter.mv[reflist][0] - mv_cand[1][0];
mvd_temp2.y = cur_cu->inter.mv[reflist][1] - mv_cand[1][1];
cand2_cost = get_mvd_coding_cost(&mvd_temp2);
// Select candidate 1 if it has lower cost
if (cand2_cost < cand1_cost) {
//cu_mv_cand = 1;
cu_mv_cand = 0;
}
}
cur_cu->inter.mvd[reflist][0] = cur_cu->inter.mv[reflist][0] - mv_cand[cu_mv_cand][0];
cur_cu->inter.mvd[reflist][1] = cur_cu->inter.mv[reflist][1] - mv_cand[cu_mv_cand][1];
}
cur_cu->inter.cost = cost;
cur_cu->inter.bitcost = bitcost[0] + bitcost[1] + cur_cu->inter.mv_dir - 1 + cur_cu->inter.mv_ref_coded[0] + cur_cu->inter.mv_ref_coded[1];
cur_cu->inter.mv_cand = cu_mv_cand;
}
}
}
}
FREE_POINTER(templcu);
}
return cur_cu->inter.cost;