[avx2] Improve avx2 version of update_common_context

This commit is contained in:
Joose Sainio 2023-05-09 11:28:23 +03:00
parent 9280d35d96
commit 8d02ff8e4d

View file

@ -497,7 +497,7 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
__m128i prev_state;
__m128i prev_state_no_offset;
__m128i abs_level = _mm_load_si128((const __m128i*)decisions->absLevel);
__m128i control = _mm_setr_epi8(0, 4, 8, 12, 0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1);
__m128i control = _mm_setr_epi8(0, 4, 8, 12, 0, 4, 8, 12, 0, 4, 8, 12, 0, 4, 8, 12);
if (all_above_four) {
prev_state = _mm_set1_epi32(ctxs->m_skip_state_offset);
prev_state_no_offset = _mm_sub_epi32(_mm_load_si128((const __m128i*)decisions->prevId), _mm_set1_epi32(4));
@ -575,34 +575,101 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
{
const uint32_t numSbb = width_in_sbb * height_in_sbb;
common_context* cc = &ctxs->m_common_context;
size_t setCpSize = cc->m_nbInfo[scan_pos - 1].maxDist * sizeof(uint8_t);
int previous_state_array[4];
size_t setCpSize = cc->m_nbInfo[scan_pos - 1].maxDist * sizeof(uint8_t);
uint8_t* sbbFlags = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset].sbbFlags;
uint8_t* levels = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset].levels + scan_pos * 4;
uint8_t* levels_in = cc->m_allSbbCtx[cc->m_prev_sbb_ctx_offset].levels + scan_pos * 4;
int previous_state_array[4];
_mm_storeu_si128((__m128i*)previous_state_array, prev_state);
for (int curr_state = 0; curr_state < 4; ++curr_state) {
uint8_t* sbbFlags = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset ].sbbFlags;
uint8_t* levels = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset].levels;
const int p_state = previous_state_array[curr_state];
if (p_state != -1 && ctxs->m_allStates.m_refSbbCtxId[p_state] >= 0) {
const int prev_sbb = ctxs->m_allStates.m_refSbbCtxId[p_state];
for (int i = 0; i < numSbb; ++i) {
sbbFlags[i * 4 + curr_state] = cc->m_allSbbCtx[cc->m_prev_sbb_ctx_offset].sbbFlags[i * 4 + prev_sbb];
}
for (int i = 16; i < setCpSize; ++i) {
levels[scan_pos * 4 + i * 4 + curr_state] = cc->m_allSbbCtx[cc->m_prev_sbb_ctx_offset].levels[scan_pos * 4 + i * 4 + prev_sbb];
if (all_have_previous_state) {
__m128i temp_p_state = _mm_shuffle_epi8(prev_state, control);
__m128i ref_sbb_ctx_offset =
_mm_load_si128((__m128i*)ctxs->m_allStates.m_refSbbCtxId);
ref_sbb_ctx_offset = _mm_shuffle_epi8(ref_sbb_ctx_offset, temp_p_state);
if (numSbb <= 4) {
__m128i incremented_ref_sbb_ctx_offset = _mm_add_epi8(
ref_sbb_ctx_offset,
_mm_setr_epi8(0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12)
);
__m128i blend_mask = _mm_cmpeq_epi8(ref_sbb_ctx_offset, _mm_set1_epi32(0xffffffff));
__m128i sbb_flags = _mm_loadu_si128((__m128i*)cc->m_allSbbCtx[cc->m_prev_sbb_ctx_offset].sbbFlags);
sbb_flags = _mm_shuffle_epi8(sbb_flags, incremented_ref_sbb_ctx_offset);
sbb_flags = _mm_blendv_epi8(sbb_flags, _mm_set1_epi64x(0), blend_mask);
if (numSbb == 2) {
uint64_t temp = _mm_extract_epi64(sbb_flags, 0);
memcpy(sbbFlags, &temp, 8);
} else {
_mm_storeu_si128((__m128i*)sbbFlags, sbb_flags);
}
} else {
for (int i = 0; i < numSbb; ++i) {
sbbFlags[i * 4 + curr_state] = 0;
}
for (int i = 16; i < setCpSize; ++i) {
levels[scan_pos * 4 + i * 4 + curr_state] = 0;
__m256i extended_ref_state = _mm256_zextsi128_si256(ref_sbb_ctx_offset);
extended_ref_state = _mm256_permute4x64_epi64(extended_ref_state, 0);
__m256i inc_ref_state = _mm256_add_epi8(
extended_ref_state,
_mm256_setr_epi32(0, 0x04040404, 0x08080808, 0x0c0c0c0c,0, 0x04040404, 0x08080808, 0x0c0c0c0c)
);
__m256i blend_mask = _mm256_cmpeq_epi8(extended_ref_state, _mm256_set1_epi32(0xffffffff));
inc_ref_state = _mm256_blendv_epi8(inc_ref_state, _mm256_set1_epi32(0xffffffff), blend_mask);
for (int i = 0; i < numSbb * 4; i += 32) {
__m256i sbb_flags = _mm256_loadu_si256((__m256i*)(&cc->m_allSbbCtx[cc->m_prev_sbb_ctx_offset].sbbFlags[i]));
sbb_flags = _mm256_shuffle_epi8(sbb_flags, inc_ref_state);
_mm256_store_si256((__m256i*)&sbbFlags[i], sbb_flags);
}
}
sbbFlags[cg_pos * 4 + curr_state] = ctxs->m_allStates.m_numSigSbb[curr_state + state_offset];
for (int i = 0; i < 16; ++i) {
levels[scan_pos * 4 + i * 4 + curr_state] = ctxs->m_allStates.m_absLevels[state_offset / 4][i * 4 + curr_state];
int levels_start = 16;
const uint64_t limit = setCpSize & ~(8 - 1);
if (levels_start < limit) {
__m256i extended_ref_state = _mm256_zextsi128_si256(ref_sbb_ctx_offset);
extended_ref_state = _mm256_permute4x64_epi64(extended_ref_state, 0);
__m256i inc_ref_state = _mm256_add_epi8(
extended_ref_state,
_mm256_setr_epi32(0, 0x04040404, 0x08080808, 0x0c0c0c0c,0, 0x04040404, 0x08080808, 0x0c0c0c0c)
);
__m256i blend_mask = _mm256_cmpeq_epi8(extended_ref_state, _mm256_set1_epi32(0xffffffff));
inc_ref_state = _mm256_blendv_epi8(inc_ref_state, _mm256_set1_epi32(0xffffffff), blend_mask);
for (; levels_start < limit; levels_start += 8) {
__m256i levels_v = _mm256_loadu_si256((__m256i*)(&levels_in[levels_start * 4]));
levels_v = _mm256_shuffle_epi8(levels_v, inc_ref_state);
_mm256_store_si256((__m256i*)&levels[levels_start * 4], levels_v);
}
}
uint8_t ref_sbb[4];
int temp_sbb_ref = _mm_extract_epi32(ref_sbb_ctx_offset, 0);
memcpy(ref_sbb, &temp_sbb_ref, 4);
for (;levels_start < setCpSize; ++levels_start) {
uint8_t new_values[4];
new_values[0] = ref_sbb[0] != 0xff ? levels_in[levels_start * 4 + ref_sbb[0]] : 0;
new_values[1] = ref_sbb[1] != 0xff ? levels_in[levels_start * 4 + ref_sbb[1]] : 0;
new_values[2] = ref_sbb[2] != 0xff ? levels_in[levels_start * 4 + ref_sbb[2]] : 0;
new_values[3] = ref_sbb[3] != 0xff ? levels_in[levels_start * 4 + ref_sbb[3]] : 0;
memcpy(&levels[levels_start * 4], new_values, 4);
}
}
else {
for (int curr_state = 0; curr_state < 4; ++curr_state) {
const int p_state = previous_state_array[curr_state];
if (p_state != -1 && ctxs->m_allStates.m_refSbbCtxId[p_state] >= 0) {
const int prev_sbb = ctxs->m_allStates.m_refSbbCtxId[p_state];
for (int i = 0; i < numSbb; ++i) {
sbbFlags[i * 4 + curr_state] = cc->m_allSbbCtx[cc->m_prev_sbb_ctx_offset].sbbFlags[i * 4 + prev_sbb];
}
for (int i = 16; i < setCpSize; ++i) {
levels[i * 4 + curr_state] = levels_in[i * 4 + prev_sbb];
}
} else {
for (int i = 0; i < numSbb; ++i) {
sbbFlags[i * 4 + curr_state] = 0;
}
for (int i = 16; i < setCpSize; ++i) {
levels[ i * 4 + curr_state] = 0;
}
}
}
}
memcpy(levels, ctxs->m_allStates.m_absLevels[state_offset / 4], 64);
memcpy(&sbbFlags[cg_pos * 4], &ctxs->m_allStates.m_numSigSbb[state_offset], 4);
__m128i sbb_right = next_sbb_right ?
_mm_cvtepi8_epi32(_mm_loadu_si128((__m128i*)&cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset].sbbFlags[next_sbb_right * 4])) :
@ -640,8 +707,6 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
const NbInfoOut* nbOut = cc->m_nbInfo + scanBeg;
const uint8_t* absLevels = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset].levels + scanBeg * 4;
__m128i levels_offsets = _mm_set_epi32(cc->num_coeff * 3, cc->num_coeff * 2, cc->num_coeff * 1, 0);
__m128i first_byte = _mm_set1_epi32(0xff);
__m128i ones = _mm_set1_epi32(1);
__m128i fours = _mm_set1_epi32(4);
__m256i all[4];
@ -738,10 +803,8 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
_mm256_storeu_si256((__m256i*)(&state->m_ctxInit[state_offset >> 2][16]), all[1]);
_mm256_storeu_si256((__m256i*)(&state->m_ctxInit[state_offset >> 2][32]), all[2]);
_mm256_storeu_si256((__m256i*)(&state->m_ctxInit[state_offset >> 2][48]), all[3]);
for (int i = 0; i < 4; ++i) {
memset(state->m_absLevels[state_offset >> 2], 0, 16 * 4);
}
memset(state->m_absLevels[state_offset >> 2], 0, 16 * 4);
}
__m128i sum_num = _mm_and_si128(last, _mm_set1_epi32(7));
@ -973,10 +1036,8 @@ static INLINE void update_states_avx2(
state->all_lt_four = rem_reg_all_lt4;
if (rem_reg_all_gte_4) {
const __m128i first_byte = _mm_set1_epi32(0xff);
const __m128i ones = _mm_set1_epi32(1);
const uint32_t tinit_offset = MIN(level_offset - 1u, 15u);
const __m128i levels_start_offsets = _mm_set_epi32(16 * 3, 16 * 2, 16 * 1, 16 * 0);
__m128i tinit = _mm_loadu_si128((__m128i*)(&state->m_ctxInit[state_offset >> 2][tinit_offset * 4]));
tinit = _mm_cvtepi16_epi32(tinit);
__m128i sum_abs1 = _mm_and_si128(_mm_srli_epi32(tinit, 3), _mm_set1_epi32(31));
@ -1119,11 +1180,8 @@ static INLINE void update_states_avx2(
}
else if (rem_reg_all_lt4) {
const __m128i first_byte = _mm_set1_epi32(0xff);
uint8_t* levels = (uint8_t*)state->m_absLevels[state_offset >> 2];
const __m128i last_byte = _mm_set1_epi32(0xff);
const uint32_t tinit_offset = MIN(level_offset - 1u, 15u);
const __m128i levels_start_offsets = _mm_set_epi32(16 * 3, 16 * 2, 16 * 1, 16 * 0);
__m128i tinit = _mm_loadu_si128((__m128i*)(&state->m_ctxInit[state_offset >> 2][tinit_offset * 4]));
tinit = _mm_cvtepi16_epi32(tinit);
__m128i sum_abs = _mm_srli_epi32(tinit, 8);