[avx2] Replace inefficient loop with AVX2 code

This commit is contained in:
Joose Sainio 2023-05-10 09:25:58 +03:00
parent bc24601369
commit f2fb641acb

View file

@ -1005,26 +1005,58 @@ static INLINE void update_states_avx2(
if (state->m_remRegBins[state_id] >= 4) { if (state->m_remRegBins[state_id] >= 4) {
state->m_remRegBins[state_id] -= (decisions->absLevel[decision_id] < 2 ? (unsigned)decisions->absLevel[decision_id] : 3); state->m_remRegBins[state_id] -= (decisions->absLevel[decision_id] < 2 ? (unsigned)decisions->absLevel[decision_id] : 3);
} }
for (int k = 0; k < 16; ++k) {
state->m_ctxInit[state_offset >> 2][k * 4 + i] = state->m_ctxInit[ctxs->m_prev_state_offset >> 2][k * 4 + decisions->prevId[decision_id]];
}
for (int k = 0; k < 16; ++k) {
state->m_absLevels[state_offset >> 2][k * 4 + i] = state->m_absLevels[ctxs->m_prev_state_offset >> 2][k * 4 + decisions->prevId[decision_id]];
}
} else { } else {
state->m_numSigSbb[state_id] = 1; state->m_numSigSbb[state_id] = 1;
state->m_refSbbCtxId[state_id] = -1; state->m_refSbbCtxId[state_id] = -1;
int ctxBinSampleRatio = 28; int ctxBinSampleRatio = 28;
//(scanInfo.chType == CHANNEL_TYPE_LUMA) ? MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_LUMA : MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_CHROMA; //(scanInfo.chType == CHANNEL_TYPE_LUMA) ? MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_LUMA : MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_CHROMA;
state->m_remRegBins[state_id] = (state->effWidth * state->effHeight * ctxBinSampleRatio) / 16 - (decisions->absLevel[decision_id] < 2 ? (unsigned)decisions->absLevel[decision_id] : 3); state->m_remRegBins[state_id] = (state->effWidth * state->effHeight * ctxBinSampleRatio) / 16 - (decisions->absLevel[decision_id] < 2 ? (unsigned)decisions->absLevel[decision_id] : 3);
for (int k = i; k < 64; k += 4) {
state->m_ctxInit[state_offset >> 2][k] = 0;
state->m_absLevels[state_offset >> 2][k] = 0;
}
} }
rem_reg_all_gte_4 &= state->m_remRegBins[state_id] >= 4; rem_reg_all_gte_4 &= state->m_remRegBins[state_id] >= 4;
rem_reg_all_lt4 &= state->m_remRegBins[state_id] < 4; rem_reg_all_lt4 &= state->m_remRegBins[state_id] < 4;
} }
{
__m256i prev_state_full = _mm256_load_si256((__m256i const*)decisions->prevId);
__m256i shuffle_mask = _mm256_setr_epi8(0, 4, 8, 12, 0, 4, 8, 12, 0, 4, 8, 12, 0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1);
prev_state_full = _mm256_shuffle_epi8(prev_state_full, shuffle_mask);
prev_state_full = _mm256_permute4x64_epi64(prev_state_full, 0);
__m256i temp_add = _mm256_setr_epi32(
0,
0x04040404,
0x08080808,
0x0c0c0c0c,
0,
0x04040404,
0x08080808,
0x0c0c0c0c);
__m256i comp_mask = _mm256_cmpeq_epi8(prev_state_full, _mm256_set1_epi64x(-1));
prev_state_full = _mm256_add_epi8(prev_state_full, temp_add);
prev_state_full = _mm256_blendv_epi8(prev_state_full, _mm256_set1_epi64x(-1), comp_mask);
for (int i = 0; i < 64; i += (256 / (8 * sizeof(uint8_t)))) {
__m256i data = _mm256_load_si256((__m256i*)&state->m_absLevels[ctxs->m_prev_state_offset >> 2][i]);
data = _mm256_shuffle_epi8(data, prev_state_full);
_mm256_store_si256((__m256i*)&state->m_absLevels[ctxs->m_curr_state_offset >> 2][i], data);
}
}
{
__m256i prev_state_full = _mm256_load_si256((__m256i const*)decisions->prevId);
__m256i shuffle_mask = _mm256_setr_epi8(0, 0, 4, 4,8, 8, 12, 12, 0, 0, 4, 4, 8, 8, 12, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1);
prev_state_full = _mm256_shuffle_epi8(prev_state_full, shuffle_mask);
prev_state_full = _mm256_permute4x64_epi64(prev_state_full, 0);
__m256i comp_mask = _mm256_cmpeq_epi8(prev_state_full, _mm256_set1_epi64x(-1));
prev_state_full = _mm256_slli_epi16(prev_state_full, 1);
__m256i temp_add = _mm256_setr_epi8(0, 1, 0, 1, 0, 1, 0, 1, 8, 9, 8, 9, 8, 9, 8, 9, 16, 17, 16, 17,16, 17,16, 17, 24, 25,24,25,24,25,24,25);
prev_state_full = _mm256_add_epi8(prev_state_full, temp_add);
prev_state_full = _mm256_blendv_epi8(prev_state_full, _mm256_set1_epi64x(-1), comp_mask);
for (int i = 0; i < 64; i += (256 / 8 / sizeof(uint16_t))) {
__m256i data = _mm256_load_si256((__m256i*)(&state->m_ctxInit[(ctxs->m_prev_state_offset >> 2)][i]));
data = _mm256_shuffle_epi8(data, prev_state_full);
_mm256_store_si256((__m256i*)(&state->m_ctxInit[(state_offset >> 2)][i]), data);
}
}
} }
uint32_t level_offset = scan_pos & 15; uint32_t level_offset = scan_pos & 15;
__m128i max_abs = _mm_min_epi32(abs_level, _mm_set1_epi32(51)); __m128i max_abs = _mm_min_epi32(abs_level, _mm_set1_epi32(51));