[lfnst] working for 32x32

This commit is contained in:
Joose Sainio 2022-06-16 08:34:38 +03:00
parent b75ce57fce
commit d7f7a2d99b
9 changed files with 151 additions and 53 deletions

View file

@ -347,8 +347,7 @@ static double search_intra_trdepth(
for (int i = start_idx; i < end_idx + 1; ++i) {
search_data->lfnst_costs[i] = MAX_DOUBLE;
}
bool constraints[2] = { pred_cu->violates_lfnst_constrained_luma,
pred_cu->lfnst_last_scan_pos };
for (int lfnst_idx = start_idx; lfnst_idx <= end_idx; lfnst_idx++) {
// Initialize lfnst variables
@ -359,6 +358,8 @@ static double search_intra_trdepth(
for (trafo = mts_start; trafo < num_transforms; trafo++) {
pred_cu->tr_idx = trafo;
bool constraints[2] = { pred_cu->violates_lfnst_constrained_luma,
pred_cu->lfnst_last_scan_pos };
if (mts_enabled) {
pred_cu->mts_last_scan_pos = 0;
pred_cu->violates_mts_coeff_constraint = 0;
@ -520,8 +521,8 @@ static double search_intra_trdepth(
pred_cu->tr_skip = best_tr_idx == MTS_SKIP;
pred_cu->tr_idx = best_tr_idx;
pred_cu->lfnst_idx = best_lfnst_idx;
pred_cu->lfnst_last_scan_pos = constraints[1];
pred_cu->violates_lfnst_constrained_luma = constraints[0];
pred_cu->lfnst_last_scan_pos = false;
pred_cu->violates_lfnst_constrained_luma = false;
nosplit_cost += best_rd_cost;
// Early stop condition for the recursive search.
@ -1435,7 +1436,7 @@ int8_t uvg_search_intra_chroma_rdo(
int lfnst_modes_to_check[3];
if(chroma_data->pred_cu.lfnst_idx) {
lfnst_modes_to_check[0] = chroma_data->pred_cu.lfnst_idx;
lfnst_modes_to_check[1] = 0;
lfnst_modes_to_check[1] = -1;
lfnst_modes_to_check[2] = -1;
}
else {
@ -1457,10 +1458,10 @@ int8_t uvg_search_intra_chroma_rdo(
uint8_t best_lfnst_index = 0;
for (int lfnst_i = 0; lfnst_i < 3; ++lfnst_i) {
const int lfnst = lfnst_modes_to_check[lfnst_i];
chroma_data[mode_i].lfnst_costs[lfnst] += mode_bits * state->lambda;
if (lfnst == -1) {
continue;
}
chroma_data[mode_i].lfnst_costs[lfnst] += mode_bits * state->lambda;
if (pred_cu->tr_depth == pred_cu->depth) {
uvg_intra_predict(
state,

View file

@ -1206,7 +1206,7 @@ const int16_t* uvg_g_mts_input[2][3][5] = {
},
};
static void mts_dct_4x4_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth)
static void mts_dct_4x4_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth, uint8_t lfnst_idx)
{
//const int height = 4;
const int width = 4;
@ -1229,7 +1229,7 @@ static void mts_dct_4x4_avx2(const int16_t* input, int16_t* output, tr_type_t ty
_mm256_store_si256((__m256i*)output, result);
}
static void mts_idct_4x4_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth)
static void mts_idct_4x4_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth, uint8_t lfnst_idx)
{
int32_t shift_1st = 7;
int32_t shift_2nd = 12 - (bitdepth - 8);
@ -1247,7 +1247,7 @@ static void mts_idct_4x4_avx2(const int16_t* input, int16_t* output, tr_type_t t
_mm256_store_si256((__m256i*)output, result);
}
static void mts_dct_8x8_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth)
static void mts_dct_8x8_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth, uint8_t lfnst_idx)
{
int32_t shift_1st = uvg_g_convert_to_bit[8] + 1 + (bitdepth - 8);
int32_t shift_2nd = uvg_g_convert_to_bit[8] + 8;
@ -1261,7 +1261,7 @@ static void mts_dct_8x8_avx2(const int16_t* input, int16_t* output, tr_type_t ty
matmul_8x8_a_bt(dct2, tmpres, output, shift_2nd);
}
static void mts_idct_8x8_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth)
static void mts_idct_8x8_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth, uint8_t lfnst_idx)
{
int32_t shift_1st = 7;
int32_t shift_2nd = 12 - (bitdepth - 8);
@ -1275,7 +1275,7 @@ static void mts_idct_8x8_avx2(const int16_t* input, int16_t* output, tr_type_t t
}
static void mts_dct_16x16_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth)
static void mts_dct_16x16_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth, uint8_t lfnst_idx)
{
int32_t shift_1st = uvg_g_convert_to_bit[16] + 1 + (bitdepth - 8);
int32_t shift_2nd = uvg_g_convert_to_bit[16] + 8;
@ -1306,6 +1306,25 @@ static void mts_dct_16x16_avx2(const int16_t* input, int16_t* output, tr_type_t
// multiply completely
matmul_16x16_a_bt(d_v, i_v, tmp, shift_1st);
matmul_16x16_a_bt(d_v2, tmp, o_v, shift_2nd);
const int skip_line = lfnst_idx ? 8 : 0;
const int skip_line2 = lfnst_idx ? 8 : 0;
if (skip_line)
{
const int reduced_line = 8, cutoff = 8;
int16_t* dst2 = output + reduced_line;
for (int j = 0; j < cutoff; j++)
{
memset(dst2, 0, sizeof(int16_t) * skip_line);
dst2 += 16;
}
}
if (skip_line2)
{
int16_t* dst2 = output + 16 * 8;
memset(dst2, 0, sizeof(int16_t) * 16 * skip_line2);
}
}
/**********/
@ -1395,7 +1414,7 @@ static void partial_butterfly_inverse_16_mts_avx2(const int16_t* src, int16_t* d
}
}
static void mts_idct_16x16_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth)
static void mts_idct_16x16_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth, uint8_t lfnst_idx)
{
int32_t shift_1st = 7;
int32_t shift_2nd = 12 - (bitdepth - 8);
@ -1510,7 +1529,7 @@ static void mul_clip_matrix_32x32_mts_avx2(const int16_t* left,
}
}
static void mts_dct_32x32_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth)
static void mts_dct_32x32_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth, uint8_t lfnst_idx)
{
int32_t shift_1st = uvg_g_convert_to_bit[32] + 1 + (bitdepth - 8);
int32_t shift_2nd = uvg_g_convert_to_bit[32] + 8;
@ -1519,15 +1538,19 @@ static void mts_dct_32x32_avx2(const int16_t* input, int16_t* output, tr_type_t
const int16_t* tdct = uvg_g_mts_input[1][type_hor][3];
const int16_t* dct = uvg_g_mts_input[0][type_ver][3];
const int skip_width = (type_hor != DCT2) ? 16 : 0;
const int skip_height = (type_ver != DCT2) ? 16 : 0;
int skip_width = (type_hor != DCT2) ? 16 : 0;
int skip_height = (type_ver != DCT2) ? 16 : 0;
if(lfnst_idx) {
skip_width = 24;
skip_height = 24;
}
mul_clip_matrix_32x32_mts_avx2(input, tdct, tmp, shift_1st, skip_width, 0 );
mul_clip_matrix_32x32_mts_avx2(dct, tmp, output, shift_2nd, skip_width, skip_height);
}
static void mts_idct_32x32_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth)
static void mts_idct_32x32_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth, uint8_t lfnst_idx)
{
int32_t shift_1st = 7;
int32_t shift_2nd = 12 - (bitdepth - 8);
@ -1542,7 +1565,7 @@ static void mts_idct_32x32_avx2(const int16_t* input, int16_t* output, tr_type_t
mul_clip_matrix_32x32_mts_avx2(tmp, dct, output, shift_2nd, 0, 0);
}
typedef void tr_func(const int16_t*, int16_t*, tr_type_t , tr_type_t , uint8_t);
typedef void tr_func(const int16_t*, int16_t*, tr_type_t , tr_type_t , uint8_t, uint8_t);
// ToDo: Enable MTS 2x2 and 64x64 transforms
static tr_func* dct_table[5] = {
@ -1576,7 +1599,7 @@ static void mts_dct_avx2(
uvg_get_tr_type(width, color, tu, &type_hor, &type_ver, mts_idx);
if (type_hor == DCT2 && type_ver == DCT2)
if (type_hor == DCT2 && type_ver == DCT2 && !tu->lfnst_idx)
{
dct_func* dct_func = uvg_get_dct_func(width, color, tu->type);
dct_func(bitdepth, input, output);
@ -1587,7 +1610,7 @@ static void mts_dct_avx2(
tr_func* dct = dct_table[log2_width_minus2];
dct(input, output, type_hor, type_ver, bitdepth);
dct(input, output, type_hor, type_ver, bitdepth, tu->lfnst_idx);
}
}
@ -1617,7 +1640,7 @@ static void mts_idct_avx2(
tr_func* idct = idct_table[log2_width_minus2];
idct(input, output, type_hor, type_ver, bitdepth);
idct(input, output, type_hor, type_ver, bitdepth, tu->lfnst_idx);
}
}

View file

@ -669,7 +669,7 @@ int uvg_quantize_residual_avx2(encoder_state_t *const state,
}
else {
uvg_quant(state, coeff, coeff_out, width, width, color,
scan_order, cur_cu->type, cur_cu->tr_idx == MTS_SKIP && color == COLOR_Y);
scan_order, cur_cu->type, cur_cu->tr_idx == MTS_SKIP && color == COLOR_Y, cur_cu->lfnst_idx);
}
// Check if there are any non-zero coefficients.

View file

@ -1366,6 +1366,11 @@ static void fastForwardDCT2_B32(const int16_t* src, int16_t* dst, int32_t shift,
dst += line;
}
}
if (skip_line2) {
const int reduced_line = line - skip_line2;
dst = p_coef + reduced_line * 32;
memset(dst, 0, skip_line2 * 32 * sizeof(coeff_t));
}
}
static void fastInverseDCT2_B32(const int16_t* src, int16_t* dst, int32_t shift, int line, int skip_line, int skip_line2)
@ -2491,7 +2496,7 @@ static void mts_dct_generic(
uvg_get_tr_type(width, color, tu, &type_hor, &type_ver, mts_idx);
if (type_hor == DCT2 && type_ver == DCT2)
if (type_hor == DCT2 && type_ver == DCT2 && !tu->lfnst_idx)
{
dct_func *dct_func = uvg_get_dct_func(width, color, tu->type);
dct_func(bitdepth, input, output);
@ -2499,9 +2504,21 @@ static void mts_dct_generic(
else
{
const int height = width;
const int skip_width = (type_hor != DCT2 && width == 32) ? 16 : (width > 32 ? width - 32 : 0);
const int skip_height = (type_ver != DCT2 && height == 32) ? 16 : (height > 32 ? height - 32 : 0);
int skip_width = (type_hor != DCT2 && width == 32) ? 16 : (width > 32 ? width - 32 : 0);
int skip_height = (type_ver != DCT2 && height == 32) ? 16 : (height > 32 ? height - 32 : 0);
const int log2_width_minus2 = uvg_g_convert_to_bit[width];
if(tu->lfnst_idx) {
if ((width == 4 && height > 4) || (width > 4 && height == 4))
{
skip_width = width - 4;
skip_height = height - 4;
}
else if ((width >= 8 && height >= 8))
{
skip_width = width - 8;
skip_height = height - 8;
}
}
partial_tr_func* dct_hor = dct_table[type_hor][log2_width_minus2];
partial_tr_func* dct_ver = dct_table[type_ver][log2_width_minus2];

View file

@ -49,8 +49,17 @@
* \brief quantize transformed coefficients
*
*/
void uvg_quant_generic(const encoder_state_t * const state, coeff_t *coef, coeff_t *q_coef, int32_t width,
int32_t height, color_t color, int8_t scan_idx, int8_t block_type, int8_t transform_skip)
void uvg_quant_generic(
const encoder_state_t * const state,
coeff_t *coef,
coeff_t *q_coef,
int32_t width,
int32_t height,
color_t color,
int8_t scan_idx,
int8_t block_type,
int8_t transform_skip,
uint8_t lfnst_idx)
{
const encoder_control_t * const encoder = state->encoder_control;
const uint32_t log2_block_size = uvg_g_convert_to_bit[width] + 2;
@ -69,20 +78,41 @@ void uvg_quant_generic(const encoder_state_t * const state, coeff_t *coef, coeff
uint32_t ac_sum = 0;
for (int32_t n = 0; n < width * height; n++) {
int32_t level = coef[n];
int64_t abs_level = (int64_t)abs(level);
int32_t sign;
if(lfnst_idx == 0){
for (int32_t n = 0; n < width * height; n++) {
int32_t level = coef[n];
int64_t abs_level = (int64_t)abs(level);
int32_t sign;
sign = (level < 0 ? -1 : 1);
sign = (level < 0 ? -1 : 1);
int32_t curr_quant_coeff = quant_coeff[n];
level = (int32_t)((abs_level * curr_quant_coeff + add) >> q_bits);
ac_sum += level;
int32_t curr_quant_coeff = quant_coeff[n];
level = (int32_t)((abs_level * curr_quant_coeff + add) >> q_bits);
ac_sum += level;
level *= sign;
q_coef[n] = (coeff_t)(CLIP(-32768, 32767, level));
level *= sign;
q_coef[n] = (coeff_t)(CLIP(-32768, 32767, level));
}
}
else {
const int max_number_of_coeffs = ((width == 4 && height == 4) || (width == 8 && height == 8)) ? 8 : 16;
memset(q_coef, 0, width * height * sizeof(coeff_t));
for (int32_t n = 0; n < max_number_of_coeffs; n++) {
const uint32_t idx = scan[n];
int32_t level = coef[idx];
int64_t abs_level = (int64_t)abs(level);
int32_t sign;
sign = (level < 0 ? -1 : 1);
int32_t curr_quant_coeff = quant_coeff[n];
level = (abs_level * curr_quant_coeff + add) >> q_bits;
ac_sum += level;
level *= sign;
q_coef[idx] = (coeff_t)(CLIP(-32768, 32767, level));
}
}
// Signhiding
@ -90,13 +120,27 @@ void uvg_quant_generic(const encoder_state_t * const state, coeff_t *coef, coeff
int32_t delta_u[LCU_WIDTH*LCU_WIDTH >> 2];
for (int32_t n = 0; n < width * height; n++) {
int32_t level = coef[n];
int64_t abs_level = (int64_t)abs(level);
int32_t curr_quant_coeff = quant_coeff[n];
if(lfnst_idx == 0) {
for (int32_t n = 0; n < width * height; n++) {
int32_t level = coef[n];
int64_t abs_level = (int64_t)abs(level);
int32_t curr_quant_coeff = quant_coeff[n];
level = (int32_t)((abs_level * curr_quant_coeff + add) >> q_bits);
delta_u[n] = (int32_t)((abs_level * curr_quant_coeff - (level << q_bits)) >> q_bits8);
level = (int32_t)((abs_level * curr_quant_coeff + add) >> q_bits);
delta_u[n] = (int32_t)((abs_level * curr_quant_coeff - (level << q_bits)) >> q_bits8);
}
}
else {
const int max_number_of_coeffs = ((width == 4 && height == 4) || (width == 8 && height == 8)) ? 8 : 16;
for (int32_t n = 0; n < max_number_of_coeffs; n++) {
const uint32_t idx = scan[n];
int32_t level = coef[idx];
int64_t abs_level = (int64_t)abs(level);
int32_t curr_quant_coeff = quant_coeff[idx];
level = (abs_level * curr_quant_coeff + add) >> q_bits;
delta_u[idx] = (int32_t)((abs_level * curr_quant_coeff - (level << q_bits)) >> q_bits8);
}
}
if (ac_sum >= 2) {
@ -277,7 +321,7 @@ int uvg_quant_cbcr_residual_generic(
}
else {
uvg_quant(state, coeff, coeff_out, width, width, cur_cu->joint_cb_cr == 1 ? COLOR_V : COLOR_U,
scan_order, cur_cu->type, cur_cu->tr_idx == MTS_SKIP && false);
scan_order, cur_cu->type, cur_cu->tr_idx == MTS_SKIP && false, cur_cu->lfnst_idx);
}
int8_t has_coeffs = 0;
@ -455,7 +499,7 @@ int uvg_quantize_residual_generic(encoder_state_t *const state,
} else {
uvg_quant(state, coeff, coeff_out, width, width, color,
scan_order, cur_cu->type, cur_cu->tr_idx == MTS_SKIP && color == COLOR_Y);
scan_order, cur_cu->type, cur_cu->tr_idx == MTS_SKIP && color == COLOR_Y, cur_cu->lfnst_idx);
}
// Check if there are any non-zero coefficients.

View file

@ -47,8 +47,17 @@
#define QUANT_SHIFT 14
int uvg_strategy_register_quant_generic(void* opaque, uint8_t bitdepth);
void uvg_quant_generic(const encoder_state_t * const state, coeff_t *coef, coeff_t *q_coef, int32_t width,
int32_t height, color_t color, int8_t scan_idx, int8_t block_type, int8_t transform_skip);
void uvg_quant_generic(
const encoder_state_t * const state,
coeff_t *coef,
coeff_t *q_coef,
int32_t width,
int32_t height,
color_t color,
int8_t scan_idx,
int8_t block_type,
int8_t transform_skip,
uint8_t lfnst_idx);
int uvg_quantize_residual_generic(encoder_state_t *const state,
const cu_info_t *const cur_cu, const int width, const color_t color,

View file

@ -46,7 +46,7 @@
// Declare function pointers.
typedef unsigned (quant_func)(const encoder_state_t * const state, coeff_t *coef, coeff_t *q_coef, int32_t width,
int32_t height, color_t color, int8_t scan_idx, int8_t block_type, int8_t transform_skip);
int32_t height, color_t color, int8_t scan_idx, int8_t block_type, int8_t transform_skip, uint8_t lfnst_idx);
typedef unsigned (quant_cbcr_func)(
encoder_state_t* const state,
const cu_info_t* const cur_cu,

View file

@ -246,7 +246,7 @@ void uvg_transform2d(const encoder_control_t * const encoder,
color_t color,
const cu_info_t *tu)
{
if (encoder->cfg.mts)
if (encoder->cfg.mts || tu->lfnst_idx)
{
uvg_mts_dct(encoder->bitdepth, color, tu, block_size, block, coeff, encoder->cfg.mts);
}
@ -412,7 +412,8 @@ static void quantize_chroma(
coeff_t v_quant_coeff[1024],
const coeff_scan_order_t scan_order,
bool* u_has_coeffs,
bool* v_has_coeffs)
bool* v_has_coeffs,
uint8_t lfnst_idx)
{
if (state->encoder_control->cfg.rdoq_enable &&
(transforms[i] != CHROMA_TS || !state->encoder_control->cfg.rdoq_skip))
@ -442,11 +443,11 @@ static void quantize_chroma(
}
else {
uvg_quant(state, &u_coeff[i * trans_offset], u_quant_coeff, width, height, transforms[i] != JCCR_1 ? COLOR_U : COLOR_V,
scan_order, CU_INTRA, transforms[i] == CHROMA_TS);
scan_order, CU_INTRA, transforms[i] == CHROMA_TS, lfnst_idx);
if (!IS_JCCR_MODE(transforms[i])) {
uvg_quant(state, &v_coeff[i * trans_offset], v_quant_coeff, width, height, COLOR_V,
scan_order, CU_INTRA, transforms[i] == CHROMA_TS);
scan_order, CU_INTRA, transforms[i] == CHROMA_TS, lfnst_idx);
}
}
@ -550,7 +551,9 @@ void uvg_chroma_transform_search(
v_quant_coeff,
scan_order,
&u_has_coeffs,
&v_has_coeffs);
&v_has_coeffs,
pred_cu->lfnst_idx);
/*
if(pred_cu->type == CU_INTRA && transforms[i] != CHROMA_TS) {
bool constraints[2] = { false, false };
uvg_derive_lfnst_constraints(pred_cu, depth, constraints, &u_coeff[i * trans_offset], width, height);
@ -558,7 +561,7 @@ void uvg_chroma_transform_search(
uvg_derive_lfnst_constraints(pred_cu, depth, constraints, &v_coeff[i * trans_offset], width, height);
}
if ((constraints[0] || !constraints[1]) && pred_cu->lfnst_idx != 0) continue;
}
}*/
if (IS_JCCR_MODE(transforms[i]) && !u_has_coeffs) continue;

View file

@ -157,6 +157,7 @@ TEST dct(void)
cu_info_t tu;
tu.type = CU_INTRA;
tu.tr_idx = MTS_DST7_DST7 + trafo;
tu.lfnst_idx = 0;
int16_t* buf = dct_bufs[trafo * NUM_SIZES + blocksize];
ALIGNED(32) int16_t test_result[LCU_WIDTH * LCU_WIDTH] = { 0 };