[tmvp] Round TMVP colocated MVs using float conversion from VTM

This commit is contained in:
Marko Viitanen 2021-11-24 21:57:46 +02:00
parent 37405ffdee
commit 83430a49f2

View file

@ -38,6 +38,7 @@
#include "encoder.h" #include "encoder.h"
#include "imagelist.h" #include "imagelist.h"
#include "kvz_math.h"
#include "strategies/generic/picture-generic.h" #include "strategies/generic/picture-generic.h"
#include "strategies/strategies-ipol.h" #include "strategies/strategies-ipol.h"
#include "videoframe.h" #include "videoframe.h"
@ -1098,6 +1099,47 @@ static INLINE int16_t get_scaled_mv(int16_t mv, int scale)
return CLIP(-32768, 32767, (scaled + 127 + (scaled < 0)) >> 8); return CLIP(-32768, 32767, (scaled + 127 + (scaled < 0)) >> 8);
} }
#define MV_EXPONENT_BITCOUNT 4
#define MV_MANTISSA_BITCOUNT 6
#define MV_MANTISSA_UPPER_LIMIT ((1 << (MV_MANTISSA_BITCOUNT - 1)) - 1)
#define MV_MANTISSA_LIMIT (1 << (MV_MANTISSA_BITCOUNT - 1))
#define MV_EXPONENT_MASK ((1 << MV_EXPONENT_BITCOUNT) - 1)
static int convert_mv_fixed_to_float(int32_t val)
{
int sign = val >> 31;
int scale = kvz_math_floor_log2((val ^ sign) | MV_MANTISSA_UPPER_LIMIT) - (MV_MANTISSA_BITCOUNT - 1);
int exponent;
int mantissa;
if (scale >= 0)
{
int round = (1 << scale) >> 1;
int n = (val + round) >> scale;
exponent = scale + ((n ^ sign) >> (MV_MANTISSA_BITCOUNT - 1));
mantissa = (n & MV_MANTISSA_UPPER_LIMIT) | (sign << (MV_MANTISSA_BITCOUNT - 1));
}
else
{
exponent = 0;
mantissa = val;
}
return exponent | (mantissa << MV_EXPONENT_BITCOUNT);
}
static int convert_mv_float_to_fixed(int val)
{
int exponent = val & MV_EXPONENT_MASK;
int mantissa = val >> MV_EXPONENT_BITCOUNT;
return exponent == 0 ? mantissa : (mantissa ^ MV_MANTISSA_LIMIT) << (exponent - 1);
}
static int round_mv_comp(int x)
{
return convert_mv_float_to_fixed(convert_mv_fixed_to_float(x));
}
static void apply_mv_scaling_pocs(int32_t current_poc, static void apply_mv_scaling_pocs(int32_t current_poc,
int32_t current_ref_poc, int32_t current_ref_poc,
int32_t neighbor_poc, int32_t neighbor_poc,
@ -1187,6 +1229,10 @@ static bool add_temporal_candidate(const encoder_state_t *state,
mv_out[0] = colocated->inter.mv[col_list][0]; mv_out[0] = colocated->inter.mv[col_list][0];
mv_out[1] = colocated->inter.mv[col_list][1]; mv_out[1] = colocated->inter.mv[col_list][1];
mv_out[0] = round_mv_comp(mv_out[0]);
mv_out[1] = round_mv_comp(mv_out[1]);
apply_mv_scaling_pocs( apply_mv_scaling_pocs(
state->frame->poc, state->frame->poc,
state->frame->ref->pocs[current_ref], state->frame->ref->pocs[current_ref],