[avx2] Simplify

This commit is contained in:
Joose Sainio 2023-04-26 10:34:41 +03:00
parent 2811ce58f4
commit b4c84e820c
2 changed files with 34 additions and 54 deletions

View file

@ -656,7 +656,7 @@ void uvg_dep_quant_update_state_eos(
} }
else if (decisions->prevId[decision_id] >= 0) { else if (decisions->prevId[decision_id] >= 0) {
prvState = ctxs->m_prev_state_offset + decisions->prevId[decision_id]; prvState = ctxs->m_prev_state_offset + decisions->prevId[decision_id];
state->m_numSigSbb[curr_state_offset] = state->m_numSigSbb[prvState] + !!decisions->absLevel[decision_id]; state->m_numSigSbb[curr_state_offset] = state->m_numSigSbb[prvState] || !!decisions->absLevel[decision_id];
memcpy(state->m_absLevelsAndCtxInit[curr_state_offset], state->m_absLevelsAndCtxInit[prvState], 16 * sizeof(uint8_t)); memcpy(state->m_absLevelsAndCtxInit[curr_state_offset], state->m_absLevelsAndCtxInit[prvState], 16 * sizeof(uint8_t));
} }
else { else {

View file

@ -158,29 +158,20 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
if (spt == SCAN_ISCSBB) { if (spt == SCAN_ISCSBB) {
__m256i original = _mm256_loadu_si256((__m256i const*)state->m_sigFracBits[start]); __m256i original = _mm256_loadu_si256((__m256i const*)state->m_sigFracBits[start]);
__m256i even_mask = _mm256_setr_epi32(0, 2, 4, 6, -1, -1, -1, -1); __m256i even = _mm256_and_si256(original, _mm256_set1_epi64x(0xffffffff));
__m256i odd_mask = _mm256_setr_epi32(1, 3, 5, 7, -1, -1, -1, -1); __m256i odd = _mm256_srli_epi64(original, 32);
__m256i even = _mm256_permutevar8x32_epi32(original, even_mask); rd_cost_a = _mm256_add_epi64(rd_cost_a, odd);
__m256i odd = _mm256_permutevar8x32_epi32(original, odd_mask); rd_cost_b = _mm256_add_epi64(rd_cost_b, odd);
__m256i even_64 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(even, 0)); rd_cost_z = _mm256_add_epi64(rd_cost_z, even);
__m256i odd_64 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(odd, 0));
rd_cost_a = _mm256_add_epi64(rd_cost_a, odd_64);
rd_cost_b = _mm256_add_epi64(rd_cost_b, odd_64);
rd_cost_z = _mm256_add_epi64(rd_cost_z, even_64);
} else if (spt == SCAN_SOCSBB) { } else if (spt == SCAN_SOCSBB) {
__m256i original = _mm256_loadu_si256((__m256i const*)state->m_sigFracBits[start]); __m256i original = _mm256_loadu_si256((__m256i const*)state->m_sigFracBits[start]);
__m256i even_mask = _mm256_setr_epi32(0, 2, 4, 6, -1, -1, -1, -1);
__m256i odd_mask = _mm256_setr_epi32(1, 3, 5, 7, -1, -1, -1, -1); __m256i m_sigFracBits_0 = _mm256_and_si256(original, _mm256_set1_epi64x(0xffffffff));
__m256i even = _mm256_permutevar8x32_epi32(original, even_mask); __m256i m_sigFracBits_1 = _mm256_srli_epi64(original, 32);
__m256i odd = _mm256_permutevar8x32_epi32(original, odd_mask);
__m256i m_sigFracBits_0 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(even, 0));
__m256i m_sigFracBits_1 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(odd, 0));
original = _mm256_loadu_si256((__m256i const*)state->m_sbbFracBits[start]); original = _mm256_loadu_si256((__m256i const*)state->m_sbbFracBits[start]);
odd = _mm256_permutevar8x32_epi32(original, odd_mask); __m256i m_sbbFracBits_1 = _mm256_srli_epi64(original, 32);
__m256i m_sbbFracBits_1 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(odd, 0));
rd_cost_a = _mm256_add_epi64(rd_cost_a, m_sbbFracBits_1); rd_cost_a = _mm256_add_epi64(rd_cost_a, m_sbbFracBits_1);
rd_cost_b = _mm256_add_epi64(rd_cost_b, m_sbbFracBits_1); rd_cost_b = _mm256_add_epi64(rd_cost_b, m_sbbFracBits_1);
rd_cost_z = _mm256_add_epi64(rd_cost_z, m_sbbFracBits_1); rd_cost_z = _mm256_add_epi64(rd_cost_z, m_sbbFracBits_1);
@ -190,19 +181,17 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
rd_cost_z = _mm256_add_epi64(rd_cost_z, m_sigFracBits_0); rd_cost_z = _mm256_add_epi64(rd_cost_z, m_sigFracBits_0);
} }
else { else {
if (state->m_numSigSbb[start] && state->m_numSigSbb[start + 1] && state->m_numSigSbb[start + 2] && state->m_numSigSbb[start + 3]) { int num_sig_sbb;
memcpy(&num_sig_sbb, &state->m_numSigSbb[start], 4);
if (num_sig_sbb == 0x01010101) {
__m256i original = _mm256_loadu_si256((__m256i const*)state->m_sigFracBits[start]); __m256i original = _mm256_loadu_si256((__m256i const*)state->m_sigFracBits[start]);
__m256i even_mask = _mm256_setr_epi32(0, 2, 4, 6, -1, -1, -1, -1); __m256i even = _mm256_and_si256(original, _mm256_set1_epi64x(0xffffffff));
__m256i odd_mask = _mm256_setr_epi32(1, 3, 5, 7, -1, -1, -1, -1); __m256i odd = _mm256_srli_epi64(original, 32);
__m256i even = _mm256_permutevar8x32_epi32(original, even_mask); rd_cost_a = _mm256_add_epi64(rd_cost_a, odd);
__m256i odd = _mm256_permutevar8x32_epi32(original, odd_mask); rd_cost_b = _mm256_add_epi64(rd_cost_b, odd);
__m256i even_64 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(even, 0)); rd_cost_z = _mm256_add_epi64(rd_cost_z, even);
__m256i odd_64 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(odd, 0));
rd_cost_a = _mm256_add_epi64(rd_cost_a, odd_64);
rd_cost_b = _mm256_add_epi64(rd_cost_b, odd_64);
rd_cost_z = _mm256_add_epi64(rd_cost_z, even_64);
} }
else if (!state->m_numSigSbb[start] && !state->m_numSigSbb[start + 1] && !state->m_numSigSbb[start + 2] && !state->m_numSigSbb[start + 3]) { else if (num_sig_sbb == 0) {
rd_cost_z = _mm256_setr_epi64x(decisions->rdCost[0], decisions->rdCost[0], decisions->rdCost[3], decisions->rdCost[3]); rd_cost_z = _mm256_setr_epi64x(decisions->rdCost[0], decisions->rdCost[0], decisions->rdCost[3], decisions->rdCost[3]);
} }
@ -527,7 +516,7 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
__m128i prev_state_with_ff_high_bytes = _mm_or_si128(prev_state, _mm_set1_epi32(0xffffff00)); __m128i prev_state_with_ff_high_bytes = _mm_or_si128(prev_state, _mm_set1_epi32(0xffffff00));
__m128i num_sig_sbb = _mm_load_si128((const __m128i*)state->m_numSigSbb); __m128i num_sig_sbb = _mm_load_si128((const __m128i*)state->m_numSigSbb);
num_sig_sbb = _mm_shuffle_epi8(num_sig_sbb, prev_state_with_ff_high_bytes); num_sig_sbb = _mm_shuffle_epi8(num_sig_sbb, prev_state_with_ff_high_bytes);
num_sig_sbb = _mm_add_epi32( num_sig_sbb = _mm_or_si128(
num_sig_sbb, num_sig_sbb,
_mm_min_epi32(abs_level, _mm_set1_epi32(1)) _mm_min_epi32(abs_level, _mm_set1_epi32(1))
); );
@ -552,7 +541,7 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
memset(state->m_absLevelsAndCtxInit[curr_state_offset], 0, 16 * sizeof(uint8_t)); memset(state->m_absLevelsAndCtxInit[curr_state_offset], 0, 16 * sizeof(uint8_t));
} else if (decisions->prevId[decision_id] >= 0) { } else if (decisions->prevId[decision_id] >= 0) {
prev_state_s[i] = ctxs->m_prev_state_offset + decisions->prevId[decision_id]; prev_state_s[i] = ctxs->m_prev_state_offset + decisions->prevId[decision_id];
state->m_numSigSbb[curr_state_offset] = state->m_numSigSbb[prev_state_s[i]] + !!decisions->absLevel[decision_id]; state->m_numSigSbb[curr_state_offset] = state->m_numSigSbb[prev_state_s[i]] || !!decisions->absLevel[decision_id];
memcpy(state->m_absLevelsAndCtxInit[curr_state_offset], state->m_absLevelsAndCtxInit[prev_state_s[i]], 16 * sizeof(uint8_t)); memcpy(state->m_absLevelsAndCtxInit[curr_state_offset], state->m_absLevelsAndCtxInit[prev_state_s[i]], 16 * sizeof(uint8_t));
} else { } else {
state->m_numSigSbb[curr_state_offset] = 1; state->m_numSigSbb[curr_state_offset] = 1;
@ -591,7 +580,7 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
memset(sbbFlags, 0, numSbb * sizeof(uint8_t)); memset(sbbFlags, 0, numSbb * sizeof(uint8_t));
memset(levels + scan_pos, 0, setCpSize); memset(levels + scan_pos, 0, setCpSize);
} }
sbbFlags[cg_pos] = !!ctxs->m_allStates.m_numSigSbb[curr_state + state_offset]; sbbFlags[cg_pos] = ctxs->m_allStates.m_numSigSbb[curr_state + state_offset];
memcpy(levels + scan_pos, ctxs->m_allStates.m_absLevelsAndCtxInit[curr_state + state_offset], 16 * sizeof(uint8_t)); memcpy(levels + scan_pos, ctxs->m_allStates.m_absLevelsAndCtxInit[curr_state + state_offset], 16 * sizeof(uint8_t));
} }
@ -996,7 +985,7 @@ static INLINE void update_states_avx2(
); );
sum_num = _mm_add_epi32( sum_num = _mm_add_epi32(
sum_num, sum_num,
_mm_min_epi32(_mm_and_si128(t, first_byte), ones)); _mm_min_epi32(t, ones));
} }
case 4: case 4:
{ {
@ -1013,9 +1002,7 @@ static INLINE void update_states_avx2(
sum_abs1, sum_abs1,
min_arg min_arg
); );
sum_num = _mm_add_epi32( sum_num = _mm_add_epi32(sum_num, _mm_min_epi32(t, ones));
sum_num,
_mm_min_epi32(_mm_and_si128(t, first_byte), ones));
} }
case 3: case 3:
{ {
@ -1032,9 +1019,7 @@ static INLINE void update_states_avx2(
sum_abs1, sum_abs1,
min_arg min_arg
); );
sum_num = _mm_add_epi32( sum_num = _mm_add_epi32(sum_num, _mm_min_epi32(t, ones));
sum_num,
_mm_min_epi32(_mm_and_si128(t, first_byte), ones));
} }
case 2: case 2:
{ {
@ -1051,9 +1036,7 @@ static INLINE void update_states_avx2(
sum_abs1, sum_abs1,
min_arg min_arg
); );
sum_num = _mm_add_epi32( sum_num = _mm_add_epi32(sum_num, _mm_min_epi32(t, ones));
sum_num,
_mm_min_epi32(_mm_and_si128(t, first_byte), ones));
} }
case 1: { case 1: {
__m128i t = _mm_i32gather_epi32( __m128i t = _mm_i32gather_epi32(
@ -1069,9 +1052,7 @@ static INLINE void update_states_avx2(
sum_abs1, sum_abs1,
min_arg min_arg
); );
sum_num = _mm_add_epi32( sum_num = _mm_add_epi32(sum_num, _mm_min_epi32(t, ones));
sum_num,
_mm_min_epi32(_mm_and_si128(t, first_byte), ones));
} break; } break;
default: default:
assert(0); assert(0);
@ -1161,6 +1142,7 @@ static INLINE void update_states_avx2(
} }
else if (rem_reg_all_lt4) { else if (rem_reg_all_lt4) {
const __m128i first_byte = _mm_set1_epi32(0xff);
uint8_t* levels = (uint8_t*)state->m_absLevelsAndCtxInit[state_offset]; uint8_t* levels = (uint8_t*)state->m_absLevelsAndCtxInit[state_offset];
const __m128i last_two_bytes = _mm_set1_epi32(0xffff); const __m128i last_two_bytes = _mm_set1_epi32(0xffff);
const __m128i last_byte = _mm_set1_epi32(0xff); const __m128i last_byte = _mm_set1_epi32(0xff);
@ -1173,21 +1155,23 @@ static INLINE void update_states_avx2(
2); 2);
tinit = _mm_and_si128(tinit, last_two_bytes); tinit = _mm_and_si128(tinit, last_two_bytes);
__m128i sum_abs = _mm_srli_epi32(tinit, 8); __m128i sum_abs = _mm_srli_epi32(tinit, 8);
sum_abs = _mm_min_epi32(sum_abs, _mm_set1_epi32(51));
switch (numIPos) { switch (numIPos) {
case 5: { case 5: {
__m128i t = _mm_i32gather_epi32( __m128i t = _mm_i32gather_epi32(
(int*)levels, (int*)levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[4])), _mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[4])),
1); 1);
t = _mm_and_si128(t, last_byte);
sum_abs = _mm_add_epi32(sum_abs, t); sum_abs = _mm_add_epi32(sum_abs, t);
// Need this to make sure we don't go beyond 255
sum_abs = _mm_and_si128(sum_abs, first_byte);
sum_abs = _mm_min_epi32(sum_abs, _mm_set1_epi32(51));
} }
case 4: { case 4: {
__m128i t = _mm_i32gather_epi32( __m128i t = _mm_i32gather_epi32(
(int*)levels, (int*)levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[3])), _mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[3])),
1); 1);
t = _mm_and_si128(t, last_byte);
sum_abs = _mm_add_epi32(sum_abs, t); sum_abs = _mm_add_epi32(sum_abs, t);
} }
case 3: { case 3: {
@ -1195,7 +1179,6 @@ static INLINE void update_states_avx2(
(int*)levels, (int*)levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[2])), _mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[2])),
1); 1);
t = _mm_and_si128(t, last_byte);
sum_abs = _mm_add_epi32(sum_abs, t); sum_abs = _mm_add_epi32(sum_abs, t);
} }
case 2: { case 2: {
@ -1203,7 +1186,6 @@ static INLINE void update_states_avx2(
(int*)levels, (int*)levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[1])), _mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[1])),
1); 1);
t = _mm_and_si128(t, last_byte);
sum_abs = _mm_add_epi32(sum_abs, t); sum_abs = _mm_add_epi32(sum_abs, t);
} }
case 1: { case 1: {
@ -1211,12 +1193,12 @@ static INLINE void update_states_avx2(
(int*)levels, (int*)levels,
_mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[0])), _mm_add_epi32(levels_start_offsets, _mm_set1_epi32(next_nb_info_ssb.inPos[0])),
1); 1);
t = _mm_and_si128(t, last_byte);
sum_abs = _mm_add_epi32(sum_abs, t); sum_abs = _mm_add_epi32(sum_abs, t);
} break; } break;
default: default:
assert(0); assert(0);
} }
sum_abs = _mm_and_si128(sum_abs, last_byte);
if (extRiceFlag) { if (extRiceFlag) {
assert(0 && "Not implemented for avx2"); assert(0 && "Not implemented for avx2");
} else { } else {
@ -1229,10 +1211,8 @@ static INLINE void update_states_avx2(
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
state->m_goRiceZero[state_offset + i] = (i < 2 ? 1 : 2) << state->m_goRicePar[state_offset + i]; state->m_goRiceZero[state_offset + i] = (i < 2 ? 1 : 2) << state->m_goRicePar[state_offset + i];
} }
} }
} }