[inter] Implement pairwise-average candidates for merge candidates

- Half-pel candidates are skipped for now because it needs some special handling
This commit is contained in:
Marko Viitanen 2021-11-01 13:24:23 +02:00
parent 4a7e4e3e20
commit 30d97d9af6
3 changed files with 87 additions and 6 deletions

View file

@ -1460,6 +1460,12 @@ void kvz_hmvp_add_mv(const encoder_state_t* const state, uint32_t pic_x, uint32_
}
}
static void round_avg_mv(int16_t* mvx, int16_t* mvy, int nShift)
{
const int nOffset = 1 << (nShift - 1);
*mvx = (*mvx + nOffset - (*mvx >= 0)) >> nShift;
*mvy = (*mvy + nOffset - (*mvy >= 0)) >> nShift;
}
/**
* \brief Get merge predictions for current block
@ -1591,7 +1597,9 @@ uint8_t kvz_inter_get_merge_cand(const encoder_state_t * const state,
}
}
if (candidates != max_num_cands - 1) {
if (candidates == max_num_cands) return candidates;
if (0 && candidates != max_num_cands - 1) {
const uint32_t ctu_row = (y >> LOG2_LCU_WIDTH);
const uint32_t ctu_row_mul_five = ctu_row * MAX_NUM_HMVP_CANDS;
int32_t num_cand = state->frame->hmvp_size[ctu_row];
@ -1611,11 +1619,82 @@ uint8_t kvz_inter_get_merge_cand(const encoder_state_t * const state,
mv_cand[candidates].ref[1] = state->frame->hmvp_lut[ctu_row_mul_five + i].inter.mv_ref[1];
}
candidates++;
if (candidates == max_num_cands) return candidates;
if (candidates == max_num_cands - 1) break;
}
}
}
// pairwise-average candidates
if (candidates > 1 && candidates < max_num_cands)
{
// calculate average MV for L0 and L1 seperately
uint8_t inter_dir = 0;
for (int reflist = 0; reflist < (state->frame->slicetype == KVZ_SLICE_B ? 2 : 1); reflist++)
{
const int16_t ref_i = mv_cand[0].dir & (reflist + 1) ? mv_cand[0].ref[reflist] : -1;
const int16_t ref_j = mv_cand[1].dir & (reflist + 1) ? mv_cand[1].ref[reflist] : -1;
// both MVs are invalid, skip
if ((ref_i == -1) && (ref_j == -1))
{
continue;
}
inter_dir += 1 << reflist;
// both MVs are valid, average these two MVs
if ((ref_i != -1) && (ref_j != -1))
{
int16_t mv_i[2] = { mv_cand[0].mv[reflist][0], mv_cand[0].mv[reflist][1] };
int16_t mv_j[2] = { mv_cand[1].mv[reflist][0], mv_cand[1].mv[reflist][1] };
// average two MVs
int16_t avg_mv[2] = { (mv_i[0] + mv_j[0]) * 2, (mv_i[1] + mv_j[1]) * 2 };
round_avg_mv(&avg_mv[0], &avg_mv[1], 1);
if (avg_mv[0] & 1 || avg_mv[1] & 1) {
mv_cand[candidates].half_pel = true;
} else {
avg_mv[0] = avg_mv[0] >> 1;
avg_mv[1] = avg_mv[1] >> 1;
}
mv_cand[candidates].mv[reflist][0] = avg_mv[0];
mv_cand[candidates].mv[reflist][1] = avg_mv[1];
mv_cand[candidates].ref[reflist] = ref_i;
}
// only one MV is valid, take the only one MV
else if (ref_i != -1)
{
int16_t mv_i[2] = { mv_cand[0].mv[reflist][0], mv_cand[0].mv[reflist][1] };
mv_cand[candidates].mv[reflist][0] = mv_i[0];
mv_cand[candidates].mv[reflist][1] = mv_i[1];
mv_cand[candidates].ref[reflist] = ref_i;
}
else if (ref_j != -1)
{
int16_t mv_j[2] = { mv_cand[1].mv[reflist][0], mv_cand[1].mv[reflist][1] };
mv_cand[candidates].mv[reflist][0] = mv_j[0];
mv_cand[candidates].mv[reflist][1] = mv_j[1];
mv_cand[candidates].ref[reflist] = ref_j;
}
}
mv_cand[candidates].dir = inter_dir;
if (inter_dir > 0)
{
candidates++;
}
}
if (candidates == max_num_cands) return candidates;
int num_ref = state->frame->ref->used_size;
if (candidates < max_num_cands && state->frame->slicetype == KVZ_SLICE_B) {

View file

@ -34,11 +34,13 @@
typedef struct {
uint8_t dir;
uint8_t ref[2]; // index to L0/L1
int16_t mv[2][2];
uint16_t mer[2];
uint8_t dir;
uint8_t ref[2]; // index to L0/L1
/// \brief Flag for half-pel mv, otherwise mv is full-pel
bool half_pel;
} inter_merge_cand_t;
void kvz_inter_recon_cu(const encoder_state_t * const state,

View file

@ -1676,8 +1676,8 @@ static void search_pu_inter(encoder_state_t * const state,
// Check motion vector constraints and perform rough search
for (int merge_idx = 0; merge_idx < info.num_merge_cand; ++merge_idx) {
inter_merge_cand_t *cur_cand = &info.merge_cand[merge_idx];
if (cur_cand->half_pel) continue; // Skip half-pel candidates for now TODO: Fix
cur_cu->inter.mv_dir = cur_cand->dir;
cur_cu->inter.mv_ref[0] = cur_cand->ref[0];
cur_cu->inter.mv_ref[1] = cur_cand->ref[1];