[avx2] update_states_avx2 working

This commit is contained in:
Joose Sainio 2023-04-12 10:41:37 +03:00
parent 58a66c0654
commit 8f4c3cecbf

View file

@ -158,11 +158,14 @@ typedef struct
int8_t m_goRiceZero[12];
int8_t m_stateId[12];
uint32_t m_sigFracBitsArray[12][12][2];
int32_t *m_gtxFracBitsArray[21];
int32_t m_gtxFracBitsArray[21][6];
common_context* m_commonCtx;
unsigned effWidth;
unsigned effHeight;
bool all_gte_four;
bool all_lt_four;
} all_depquant_states;
typedef struct
@ -577,14 +580,8 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
rd_cost_a = _mm256_add_epi64(rd_cost_a, pq_a_delta_dist);
rd_cost_b = _mm256_add_epi64(rd_cost_b, pq_b_delta_dist);
bool all_over_or_four = true;
bool all_under_four = true;
for (int i = 0; i < 4; i++) {
all_over_or_four &= state->m_remRegBins[start + i] >= 4;
all_under_four &= state->m_remRegBins[start + i] < 4;
}
if (all_over_or_four) {
if (state->all_gte_four) {
if (pqDataA->absLevel[0] < 4 && pqDataA->absLevel[3] < 4) {
__m128i offsets = _mm_set_epi32(18 + pqDataA->absLevel[3], 12 + pqDataA->absLevel[3], 6 + pqDataA->absLevel[0], 0 + pqDataA->absLevel[0]);
__m128i coeff_frac_bits = _mm_i32gather_epi32(&state->m_coeffFracBits[start][0], offsets, 4);
@ -737,7 +734,7 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
_mm256_storeu_epi64(temp_rd_cost_a, rd_cost_a);
_mm256_storeu_epi64(temp_rd_cost_b, rd_cost_b);
_mm256_storeu_epi64(temp_rd_cost_z, rd_cost_z);
} else if (all_under_four) {
} else if (state->all_lt_four) {
__m128i scale_bits = _mm_set1_epi32(1 << SCALE_BITS);
__m128i max_rice = _mm_set1_epi32(31);
__m128i go_rice_zero = _mm_cvtepi8_epi32(_mm_loadu_epi8(&state->m_goRiceZero[start]));
@ -1274,6 +1271,8 @@ static INLINE void update_states_avx2(
all_minus_one &= decisions->prevId[i] == -1;
}
int state_offset = ctxs->m_curr_state_offset;
__m256i rd_cost = _mm256_loadu_epi64(decisions->rdCost);
_mm256_storeu_epi64(&ctxs->m_allStates.m_rdCost[state_offset], rd_cost);
if (all_above_minus_two) {
bool rem_reg_all_gte_4 = true;
@ -1312,7 +1311,7 @@ static INLINE void update_states_avx2(
memcpy(&state->m_goRicePar[state_offset], &go_rice_par_i, 4);
__m256i sbb_frac_bits = _mm256_i32gather_epi64(state->m_sbbFracBits, prv_states, 4);
__m256i sbb_frac_bits = _mm256_i32gather_epi64(state->m_sbbFracBits, prv_states, 8);
_mm256_storeu_epi64(&state->m_sbbFracBits[state_offset][0], sbb_frac_bits);
__m128i rem_reg_bins = _mm_i32gather_epi32(state->m_remRegBins, prv_states, 4);
@ -1321,7 +1320,7 @@ static INLINE void update_states_avx2(
__m128i reg_bins_sub = _mm_set1_epi32(0);
__m128i abs_level_smaller_than_two = _mm_cmplt_epi32(abs_level, _mm_set1_epi32(2));
__m128i secondary = _mm_blendv_epi8(abs_level, _mm_set1_epi32(3), abs_level_smaller_than_two);
__m128i secondary = _mm_blendv_epi8(_mm_set1_epi32(3), abs_level, abs_level_smaller_than_two);
__m128i rem_reg_bins_smaller_than_four = _mm_cmplt_epi32(rem_reg_bins, _mm_set1_epi32(4));
reg_bins_sub = _mm_blendv_epi8(secondary, reg_bins_sub, rem_reg_bins_smaller_than_four);
@ -1336,7 +1335,7 @@ static INLINE void update_states_avx2(
rem_reg_all_lt4 = (bit_mask == 0xFFFF);
for (int i = 0; i < 4; ++i) {
memcpy(state->m_absLevelsAndCtxInit[i], state->m_absLevelsAndCtxInit[prv_states_scalar[i]], 48 * sizeof(uint8_t));
memcpy(state->m_absLevelsAndCtxInit[state_offset + i], state->m_absLevelsAndCtxInit[prv_states_scalar[i]], 48 * sizeof(uint8_t));
}
}
else if (all_minus_one) {
@ -1347,8 +1346,8 @@ static INLINE void update_states_avx2(
__m128i rem_reg_bins = _mm_set1_epi32(a);
__m128i sub = _mm_blendv_epi8(
abs_level,
_mm_set1_epi32(3),
abs_level,
_mm_cmplt_epi32(abs_level, _mm_set1_epi32(2))
);
rem_reg_bins = _mm_sub_epi32(rem_reg_bins, sub);
@ -1400,18 +1399,20 @@ static INLINE void update_states_avx2(
uint8_t* levels = (uint8_t*)state->m_absLevelsAndCtxInit[state_offset + i];
levels[level_offset] = max_abs_s[i];
}
state->all_gte_four = rem_reg_all_gte_4;
state->all_lt_four = rem_reg_all_lt4;
if (rem_reg_all_gte_4) {
const __m128i last_two_bytes = _mm_set1_epi32(0xffff);
const __m128i last_byte = _mm_set1_epi32(0xff);
const __m128i first_two_bytes = _mm_set1_epi32(0xffff);
const __m128i first_byte = _mm_set1_epi32(0xff);
const __m128i ones = _mm_set1_epi32(1);
const uint32_t tinit_offset = MIN(level_offset - 1u, 15u) + 8;
const __m128i levels_start_offsets = _mm_set_epi32(48 * 3, 48 * 2, 48 * 1, 48 * 0);
const __m128i ctx_start_offsets = _mm_srli_epi32(levels_start_offsets, 1);
__m128i tinit = _mm_i32gather_epi32(
state->m_absLevelsAndCtxInit[state_offset],
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(tinit_offset)),
1);
tinit = _mm_and_epi32(tinit, last_two_bytes);
_mm_add_epi32(ctx_start_offsets, _mm_set1_epi32(tinit_offset)),
2);
tinit = _mm_and_epi32(tinit, first_two_bytes);
__m128i sum_abs1 = _mm_and_epi32(_mm_srli_epi32(tinit, 3), _mm_set1_epi32(31));
__m128i sum_num = _mm_and_epi32(tinit, _mm_set1_epi32(7));
@ -1423,12 +1424,18 @@ static INLINE void update_states_avx2(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[4])),
1);
t = _mm_and_epi32(t, first_byte);
__m128i min_arg = _mm_min_epi32(
_mm_add_epi32(_mm_set1_epi32(4), _mm_and_epi32(t, ones)),
t
);
sum_abs1 = _mm_add_epi32(
sum_abs1,
_mm_and_epi32(t, ones));
min_arg
);
sum_num = _mm_add_epi32(
sum_num,
_mm_min_epi32(_mm_and_epi32(t, last_byte), ones));
_mm_min_epi32(_mm_and_epi32(t, first_byte), ones));
}
case 4:
{
@ -1436,12 +1443,18 @@ static INLINE void update_states_avx2(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[3])),
1);
t = _mm_and_epi32(t, first_byte);
__m128i min_arg = _mm_min_epi32(
_mm_add_epi32(_mm_set1_epi32(4), _mm_and_epi32(t, ones)),
t
);
sum_abs1 = _mm_add_epi32(
sum_abs1,
_mm_and_epi32(t, ones));
min_arg
);
sum_num = _mm_add_epi32(
sum_num,
_mm_min_epi32(_mm_and_epi32(t, last_byte), ones));
_mm_min_epi32(_mm_and_epi32(t, first_byte), ones));
}
case 3:
{
@ -1449,12 +1462,18 @@ static INLINE void update_states_avx2(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[2])),
1);
t = _mm_and_epi32(t, first_byte);
__m128i min_arg = _mm_min_epi32(
_mm_add_epi32(_mm_set1_epi32(4), _mm_and_epi32(t, ones)),
t
);
sum_abs1 = _mm_add_epi32(
sum_abs1,
_mm_and_epi32(t, ones));
min_arg
);
sum_num = _mm_add_epi32(
sum_num,
_mm_min_epi32(_mm_and_epi32(t, last_byte), ones));
_mm_min_epi32(_mm_and_epi32(t, first_byte), ones));
}
case 2:
{
@ -1462,39 +1481,52 @@ static INLINE void update_states_avx2(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[1])),
1);
t = _mm_and_epi32(t, first_byte);
__m128i min_arg = _mm_min_epi32(
_mm_add_epi32(_mm_set1_epi32(4), _mm_and_epi32(t, ones)),
t
);
sum_abs1 = _mm_add_epi32(
sum_abs1,
_mm_and_epi32(t, ones));
min_arg
);
sum_num = _mm_add_epi32(
sum_num,
_mm_min_epi32(_mm_and_epi32(t, last_byte), ones));
_mm_min_epi32(_mm_and_epi32(t, first_byte), ones));
}
case 1: {
__m128i t = _mm_i32gather_epi32(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[0])),
1);
t = _mm_and_epi32(t, first_byte);
__m128i min_arg = _mm_min_epi32(
_mm_add_epi32(_mm_set1_epi32(4), _mm_and_epi32(t, ones)),
t
);
sum_abs1 = _mm_add_epi32(
sum_abs1,
_mm_and_epi32(t, ones));
min_arg
);
sum_num = _mm_add_epi32(
sum_num,
_mm_min_epi32(_mm_and_epi32(t, last_byte), ones));
_mm_min_epi32(_mm_and_epi32(t, first_byte), ones));
} break;
default:
assert(0);
}
__m128i sum_gt1 = _mm_sub_epi32(sum_abs1, sum_num);
__m128i offsets = _mm_set_epi32(24 * 3, 24 * 2, 24 * 1, 24 * 0);
__m128i offsets = _mm_set_epi32(12 * 3, 12 * 2, 12 * 1, 12 * 0);
offsets = _mm_add_epi32(offsets, _mm_set1_epi32(sigCtxOffsetNext));
__m128i temp = _mm_min_epi32(
_mm_srli_epi32(_mm_add_epi32(sum_abs1, ones), 1),
_mm_set1_epi32(3));
offsets = _mm_add_epi32(offsets, temp);
__m256i sig_frac_bits = _mm256_i32gather_epi64(state->m_sigFracBitsArray[state_offset][0], offsets, 4);
__m256i sig_frac_bits = _mm256_i32gather_epi64(state->m_sigFracBitsArray[state_offset][0], offsets, 8);
_mm256_storeu_epi64(&state->m_sigFracBits[state_offset][0], sig_frac_bits);
sum_gt1 = _mm_min_epi32(sum_gt1, _mm_set1_epi32(4));
sum_gt1 = _mm_add_epi32(sum_gt1, _mm_set1_epi32(gtxCtxOffsetNext));
uint32_t sum_gt1_s[4];
_mm_storeu_epi32(sum_gt1_s, sum_gt1);
for (int i = 0; i < 4; ++i) {
@ -1509,7 +1541,7 @@ static INLINE void update_states_avx2(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[4])),
1);
t = _mm_and_epi32(t, last_byte);
t = _mm_and_epi32(t, first_byte);
sum_abs = _mm_add_epi32(sum_abs, t);
}
case 4:
@ -1518,7 +1550,7 @@ static INLINE void update_states_avx2(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[3])),
1);
t = _mm_and_epi32(t, last_byte);
t = _mm_and_epi32(t, first_byte);
sum_abs = _mm_add_epi32(sum_abs, t);
}
case 3:
@ -1527,7 +1559,7 @@ static INLINE void update_states_avx2(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[2])),
1);
t = _mm_and_epi32(t, last_byte);
t = _mm_and_epi32(t, first_byte);
sum_abs = _mm_add_epi32(sum_abs, t);
}
case 2:
@ -1536,7 +1568,7 @@ static INLINE void update_states_avx2(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[1])),
1);
t = _mm_and_epi32(t, last_byte);
t = _mm_and_epi32(t, first_byte);
sum_abs = _mm_add_epi32(sum_abs, t);
}
case 1:
@ -1545,7 +1577,7 @@ static INLINE void update_states_avx2(
levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[0])),
1);
t = _mm_and_epi32(t, last_byte);
t = _mm_and_epi32(t, first_byte);
sum_abs = _mm_add_epi32(sum_abs, t);
} break;
default:
@ -1560,7 +1592,10 @@ static INLINE void update_states_avx2(
_mm_sub_epi32(sum_abs, _mm_set1_epi32(20))),
_mm_set1_epi32(0));
__m128i temp = _mm_i32gather_epi32(g_goRiceParsCoeff, sum_all, 4);
_mm_storeu_epi32(&state->m_goRicePar[state_offset], temp);
__m128i control = _mm_setr_epi8(0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1);
__m128i go_rice_par = _mm_shuffle_epi8(temp, control);
int go_rice_par_i = _mm_extract_epi32(go_rice_par, 0);
memcpy(&state->m_goRicePar[state_offset], &go_rice_par_i, 4);
}
}
@ -1571,10 +1606,11 @@ static INLINE void update_states_avx2(
const __m128i ones = _mm_set1_epi32(1);
const uint32_t tinit_offset = MIN(level_offset - 1u, 15u) + 8;
const __m128i levels_start_offsets = _mm_set_epi32(48 * 3, 48 * 2, 48 * 1, 48 * 0);
const __m128i ctx_start_offsets = _mm_srli_epi32(levels_start_offsets, 1);
__m128i tinit = _mm_i32gather_epi32(
state->m_absLevelsAndCtxInit[state_offset],
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(tinit_offset)),
1);
_mm_add_epi32(ctx_start_offsets, _mm_set1_epi32(tinit_offset)),
2);
tinit = _mm_and_epi32(tinit, last_two_bytes);
__m128i sum_abs = _mm_srli_epi32(tinit, 8);
switch (numIPos) {
@ -1624,22 +1660,19 @@ static INLINE void update_states_avx2(
if (extRiceFlag) {
assert(0 && "Not implemented for avx2");
} else {
__m128i sum_all = _mm_max_epi32(
_mm_min_epi32(
_mm_set1_epi32(31),
_mm_sub_epi32(sum_abs, _mm_set1_epi32(20))),
_mm_set1_epi32(0));
__m128i sum_all = _mm_min_epi32(_mm_set1_epi32(31), sum_abs);
__m128i temp = _mm_i32gather_epi32(g_goRiceParsCoeff, sum_all, 4);
__m128i control = _mm_setr_epi8(0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1);
__m128i go_rice_par = _mm_shuffle_epi8(temp, control);
int go_rice_par_i = _mm_extract_epi32(go_rice_par, 0);
memcpy(&state->m_goRicePar[state_offset], &go_rice_par_i, 4);
__m128i go_rice_zero = _mm_set_epi32(2, 2, 1, 1);
go_rice_zero = _mm_sll_epi32(go_rice_zero, temp);
go_rice_zero = _mm_shuffle_epi8(go_rice_zero, control);
int go_rice_zero_i = _mm_extract_epi32(go_rice_par, 0);
memcpy(&state->m_goRiceZero[state_offset], &go_rice_zero_i, 4);
for (int i = 0; i < 4; ++i) {
state->m_goRiceZero[state_offset + i] = (i < 2 ? 1 : 2) << state->m_goRicePar[state_offset + i];
}
}
}
@ -1729,6 +1762,8 @@ static INLINE void update_states_avx2(
}
} else {
for (int i = 0; i < 4; ++i) {
state->all_gte_four = true;
state->all_lt_four = true;
updateState(
ctxs,
numIPos,
@ -1758,7 +1793,7 @@ static INLINE void updateState(
int decision_id) {
all_depquant_states* state = &ctxs->m_allStates;
int state_id = ctxs->m_curr_state_offset + decision_id;
state->m_rdCost[state_id] = decisions->rdCost[decision_id];
// state->m_rdCost[state_id] = decisions->rdCost[decision_id];
if (decisions->prevId[decision_id] > -2) {
if (decisions->prevId[decision_id] >= 0) {
const int prvState = ctxs->m_prev_state_offset + decisions->prevId[decision_id];
@ -1784,7 +1819,8 @@ static INLINE void updateState(
decisions->absLevel[decision_id] < 2 ? (unsigned)decisions->absLevel[decision_id] : 3);
memset(state->m_absLevelsAndCtxInit[state_id], 0, 48 * sizeof(uint8_t));
}
state->all_gte_four &= state->m_remRegBins[state_id] >= 4;
state->all_lt_four &= state->m_remRegBins[state_id] < 4;
uint8_t* levels = (uint8_t*)(state->m_absLevelsAndCtxInit[state_id]);
levels[scan_pos & 15] = (uint8_t)MIN(255, decisions->absLevel[decision_id]);
@ -1860,6 +1896,10 @@ static INLINE void updateState(
state->m_goRiceZero[state_id] = ((state_id & 3) < 2 ? 1 : 2) << state->m_goRicePar[state_id];
}
}
else {
state->all_gte_four &= state->m_remRegBins[state_id] >= 4;
state->all_lt_four &= state->m_remRegBins[state_id] < 4;
}
}
static bool same[13];
@ -1947,17 +1987,6 @@ int uvg_dep_quant(
cur_tu->lfnst_idx :
cur_tu->cr_lfnst_idx;
int8_t t[4] = {2, 2, 2, 2};
__m128i pq_abs_a = _mm_set_epi32(16, 0, 16, 0);
__m128i go_rice_zero = _mm_cvtepi8_epi32(_mm_loadu_epi8(t));
__m128i cmp = _mm_cmplt_epi32(go_rice_zero, pq_abs_a);
__m128i max_rice = _mm_set1_epi32(15);
__m128i go_rice_smaller = _mm_min_epi32(pq_abs_a, max_rice);
__m128i other = _mm_sub_epi32(pq_abs_a, _mm_set1_epi32(1));
__m128i selected = _mm_blendv_epi8(go_rice_zero, other, cmp);
const int numCoeff = width * height;
memset(coeff_out, 0x00, width * height * sizeof(coeff_t));
@ -2055,9 +2084,11 @@ int uvg_dep_quant(
dep_quant_context.m_allStates.effHeight = effectHeight;
dep_quant_context.m_allStates.effWidth = effectWidth;
dep_quant_context.m_allStates.all_gte_four = true;
dep_quant_context.m_allStates.all_lt_four = false;
dep_quant_context.m_allStates.m_commonCtx = &dep_quant_context.m_common_context;
for (int i = 0; i < (compID == COLOR_Y ? 21 : 11); ++i) {
dep_quant_context.m_allStates.m_gtxFracBitsArray[i] = rate_estimator.m_gtxFracBits[i];
memcpy(dep_quant_context.m_allStates.m_gtxFracBitsArray[i], rate_estimator.m_gtxFracBits[i], sizeof(int32_t) * 6);
}
depquant_state_init(&dep_quant_context.m_startState, rate_estimator.m_sigFracBits[0][0], rate_estimator.m_gtxFracBits[0]);