diff --git a/src/search.c b/src/search.c index 9073e2bf..fb1e49fd 100644 --- a/src/search.c +++ b/src/search.c @@ -1100,6 +1100,8 @@ static int search_cu_inter(const encoder_state_t * const state, int x, int y, in // Search bi-pred positions if (state->global->slicetype == SLICE_B) { + lcu_t *templcu = MALLOC(lcu_t, 1); + cost_pixel_nxn_func *satd = pixels_get_satd_func(LCU_WIDTH >> depth); #define NUM_PRIORITY_LIST 12; static const uint8_t priorityList0[] = { 0, 1, 0, 2, 1, 2, 0, 3, 1, 3, 2, 3 }; static const uint8_t priorityList1[] = { 1, 0, 2, 0, 2, 1, 3, 0, 3, 1, 3, 2 }; @@ -1113,53 +1115,82 @@ static int search_cu_inter(const encoder_state_t * const state, int x, int y, in if ((merge_cand[i].dir & 0x1) && (merge_cand[j].dir & 0x2)) { if (merge_cand[i].ref[0] != merge_cand[j].ref[1] || merge_cand[i].mv[0][0] != merge_cand[j].mv[1][0] || - merge_cand[i].mv[0][1] != merge_cand[j].mv[1][1]) { + merge_cand[i].mv[0][1] != merge_cand[j].mv[1][1]) { + uint32_t bitcost[2]; + uint32_t cost = 0; int8_t cu_mv_cand = 0; + int16_t mv[2][2]; + pixel_t tmp_block[64 * 64]; + pixel_t tmp_pic[64 * 64]; // Force L0 and L1 references if (state->global->refmap[merge_cand[i].ref[0]].list == 2 || state->global->refmap[merge_cand[j].ref[1]].list == 1) continue; - cur_cu->inter.mv_dir = 3; - cur_cu->inter.mv_ref_coded[0] = state->global->refmap[merge_cand[i].ref[0]].idx; - cur_cu->inter.mv_ref_coded[1] = state->global->refmap[merge_cand[j].ref[1]].idx; - cur_cu->merged = 0; - cur_cu->inter.mv_ref[0] = merge_cand[i].ref[0]; - cur_cu->inter.mv_ref[1] = merge_cand[j].ref[1]; - cur_cu->inter.mv[0][0] = merge_cand[i].mv[0][0] & 0xfff8; - cur_cu->inter.mv[0][1] = merge_cand[i].mv[0][1] & 0xfff8; - cur_cu->inter.mv[1][0] = merge_cand[j].mv[1][0] & 0xfff8; - cur_cu->inter.mv[1][1] = merge_cand[j].mv[1][1] & 0xfff8; + mv[0][0] = merge_cand[i].mv[0][0] & 0xfff8; + mv[0][1] = merge_cand[i].mv[0][1] & 0xfff8; + mv[1][0] = merge_cand[j].mv[1][0] & 0xfff8; + mv[1][1] = merge_cand[j].mv[1][1] & 0xfff8; + memset(templcu->rec.y, 0, 64 * 64); + inter_recon_lcu_bipred(state, state->global->ref->images[merge_cand[i].ref[0]], state->global->ref->images[merge_cand[j].ref[1]], x, y, LCU_WIDTH >> depth, mv, templcu); - for (int reflist = 0; reflist < 2; reflist++) { - cu_mv_cand = 0; - inter_get_mv_cand(state, x, y, depth, mv_cand, cur_cu, lcu, reflist); - if ((mv_cand[0][0] != mv_cand[1][0] || mv_cand[0][1] != mv_cand[1][1])) { - vector2d_t mvd_temp1, mvd_temp2; - int cand1_cost, cand2_cost; - - mvd_temp1.x = cur_cu->inter.mv[reflist][0] - mv_cand[0][0]; - mvd_temp1.y = cur_cu->inter.mv[reflist][1] - mv_cand[0][1]; - cand1_cost = get_mvd_coding_cost(&mvd_temp1); - - mvd_temp2.x = cur_cu->inter.mv[reflist][0] - mv_cand[1][0]; - mvd_temp2.y = cur_cu->inter.mv[reflist][1] - mv_cand[1][1]; - cand2_cost = get_mvd_coding_cost(&mvd_temp2); - - // Select candidate 1 if it has lower cost - if (cand2_cost < cand1_cost) { - //cu_mv_cand = 1; - cu_mv_cand = 0; - } + for (int ypos = 0; ypos < LCU_WIDTH >> depth; ++ypos) { + int dst_y = ypos*(LCU_WIDTH >> depth); + for (int xpos = 0; xpos < (LCU_WIDTH >> depth); ++xpos) { + tmp_block[dst_y + xpos] = templcu->rec.y[((y + ypos)&(LCU_WIDTH - 1))*LCU_WIDTH + ((x + xpos)&(LCU_WIDTH - 1))]; + tmp_pic[dst_y + xpos] = frame->source->y[x + xpos + (y + ypos)*frame->source->width]; } - cur_cu->inter.mvd[reflist][0] = cur_cu->inter.mv[reflist][0] - mv_cand[cu_mv_cand][0]; - cur_cu->inter.mvd[reflist][1] = cur_cu->inter.mv[reflist][1] - mv_cand[cu_mv_cand][1]; } - cur_cu->inter.cost = 0; - cur_cu->inter.bitcost = 10 + cur_cu->inter.mv_dir - 1 + cur_cu->inter.mv_ref_coded[0] + cur_cu->inter.mv_ref_coded[1]; - cur_cu->inter.mv_cand = cu_mv_cand; - break; + + cost = satd(tmp_pic, tmp_block); + + cost += calc_mvd_cost(state, merge_cand[i].mv[0][0] & 0xfff8, merge_cand[i].mv[0][1] & 0xfff8, 0, mv_cand, merge_cand, 0, ref_idx, &bitcost[0]); + cost += calc_mvd_cost(state, merge_cand[i].mv[1][0] & 0xfff8, merge_cand[i].mv[1][1] & 0xfff8, 0, mv_cand, merge_cand, 0, ref_idx, &bitcost[1]); + + if (cost < cur_cu->inter.cost) { + + cur_cu->inter.mv_dir = 3; + cur_cu->inter.mv_ref_coded[0] = state->global->refmap[merge_cand[i].ref[0]].idx; + cur_cu->inter.mv_ref_coded[1] = state->global->refmap[merge_cand[j].ref[1]].idx; + + cur_cu->merged = 0; + cur_cu->inter.mv_ref[0] = merge_cand[i].ref[0]; + cur_cu->inter.mv_ref[1] = merge_cand[j].ref[1]; + cur_cu->inter.mv[0][0] = merge_cand[i].mv[0][0] & 0xfff8; + cur_cu->inter.mv[0][1] = merge_cand[i].mv[0][1] & 0xfff8; + cur_cu->inter.mv[1][0] = merge_cand[j].mv[1][0] & 0xfff8; + cur_cu->inter.mv[1][1] = merge_cand[j].mv[1][1] & 0xfff8; + + for (int reflist = 0; reflist < 2; reflist++) { + cu_mv_cand = 0; + inter_get_mv_cand(state, x, y, depth, mv_cand, cur_cu, lcu, reflist); + if ((mv_cand[0][0] != mv_cand[1][0] || mv_cand[0][1] != mv_cand[1][1])) { + vector2d_t mvd_temp1, mvd_temp2; + int cand1_cost, cand2_cost; + + mvd_temp1.x = cur_cu->inter.mv[reflist][0] - mv_cand[0][0]; + mvd_temp1.y = cur_cu->inter.mv[reflist][1] - mv_cand[0][1]; + cand1_cost = get_mvd_coding_cost(&mvd_temp1); + + mvd_temp2.x = cur_cu->inter.mv[reflist][0] - mv_cand[1][0]; + mvd_temp2.y = cur_cu->inter.mv[reflist][1] - mv_cand[1][1]; + cand2_cost = get_mvd_coding_cost(&mvd_temp2); + + // Select candidate 1 if it has lower cost + if (cand2_cost < cand1_cost) { + //cu_mv_cand = 1; + cu_mv_cand = 0; + } + } + cur_cu->inter.mvd[reflist][0] = cur_cu->inter.mv[reflist][0] - mv_cand[cu_mv_cand][0]; + cur_cu->inter.mvd[reflist][1] = cur_cu->inter.mv[reflist][1] - mv_cand[cu_mv_cand][1]; + } + cur_cu->inter.cost = cost; + cur_cu->inter.bitcost = bitcost[0] + bitcost[1] + cur_cu->inter.mv_dir - 1 + cur_cu->inter.mv_ref_coded[0] + cur_cu->inter.mv_ref_coded[1]; + cur_cu->inter.mv_cand = cu_mv_cand; + } } } } + FREE_POINTER(templcu); } return cur_cu->inter.cost;