From f2fb641acb9ae2a3dc6cc039d372ee1972a1a907 Mon Sep 17 00:00:00 2001 From: Joose Sainio Date: Wed, 10 May 2023 09:25:58 +0300 Subject: [PATCH] [avx2] Replace inefficient loop with AVX2 code --- src/strategies/avx2/depquant-avx2.c | 52 +++++++++++++++++++++++------ 1 file changed, 42 insertions(+), 10 deletions(-) diff --git a/src/strategies/avx2/depquant-avx2.c b/src/strategies/avx2/depquant-avx2.c index cacee3fd..82ddd498 100644 --- a/src/strategies/avx2/depquant-avx2.c +++ b/src/strategies/avx2/depquant-avx2.c @@ -1005,26 +1005,58 @@ static INLINE void update_states_avx2( if (state->m_remRegBins[state_id] >= 4) { state->m_remRegBins[state_id] -= (decisions->absLevel[decision_id] < 2 ? (unsigned)decisions->absLevel[decision_id] : 3); } - for (int k = 0; k < 16; ++k) { - state->m_ctxInit[state_offset >> 2][k * 4 + i] = state->m_ctxInit[ctxs->m_prev_state_offset >> 2][k * 4 + decisions->prevId[decision_id]]; - } - for (int k = 0; k < 16; ++k) { - state->m_absLevels[state_offset >> 2][k * 4 + i] = state->m_absLevels[ctxs->m_prev_state_offset >> 2][k * 4 + decisions->prevId[decision_id]]; - } } else { state->m_numSigSbb[state_id] = 1; state->m_refSbbCtxId[state_id] = -1; int ctxBinSampleRatio = 28; //(scanInfo.chType == CHANNEL_TYPE_LUMA) ? MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_LUMA : MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_CHROMA; state->m_remRegBins[state_id] = (state->effWidth * state->effHeight * ctxBinSampleRatio) / 16 - (decisions->absLevel[decision_id] < 2 ? (unsigned)decisions->absLevel[decision_id] : 3); - for (int k = i; k < 64; k += 4) { - state->m_ctxInit[state_offset >> 2][k] = 0; - state->m_absLevels[state_offset >> 2][k] = 0; - } } rem_reg_all_gte_4 &= state->m_remRegBins[state_id] >= 4; rem_reg_all_lt4 &= state->m_remRegBins[state_id] < 4; } + { + __m256i prev_state_full = _mm256_load_si256((__m256i const*)decisions->prevId); + __m256i shuffle_mask = _mm256_setr_epi8(0, 4, 8, 12, 0, 4, 8, 12, 0, 4, 8, 12, 0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1); + prev_state_full = _mm256_shuffle_epi8(prev_state_full, shuffle_mask); + prev_state_full = _mm256_permute4x64_epi64(prev_state_full, 0); + __m256i temp_add = _mm256_setr_epi32( + 0, + 0x04040404, + 0x08080808, + 0x0c0c0c0c, + 0, + 0x04040404, + 0x08080808, + 0x0c0c0c0c); + __m256i comp_mask = _mm256_cmpeq_epi8(prev_state_full, _mm256_set1_epi64x(-1)); + prev_state_full = _mm256_add_epi8(prev_state_full, temp_add); + prev_state_full = _mm256_blendv_epi8(prev_state_full, _mm256_set1_epi64x(-1), comp_mask); + for (int i = 0; i < 64; i += (256 / (8 * sizeof(uint8_t)))) { + __m256i data = _mm256_load_si256((__m256i*)&state->m_absLevels[ctxs->m_prev_state_offset >> 2][i]); + data = _mm256_shuffle_epi8(data, prev_state_full); + _mm256_store_si256((__m256i*)&state->m_absLevels[ctxs->m_curr_state_offset >> 2][i], data); + } + } + + { + __m256i prev_state_full = _mm256_load_si256((__m256i const*)decisions->prevId); + __m256i shuffle_mask = _mm256_setr_epi8(0, 0, 4, 4,8, 8, 12, 12, 0, 0, 4, 4, 8, 8, 12, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1); + prev_state_full = _mm256_shuffle_epi8(prev_state_full, shuffle_mask); + prev_state_full = _mm256_permute4x64_epi64(prev_state_full, 0); + __m256i comp_mask = _mm256_cmpeq_epi8(prev_state_full, _mm256_set1_epi64x(-1)); + prev_state_full = _mm256_slli_epi16(prev_state_full, 1); + __m256i temp_add = _mm256_setr_epi8(0, 1, 0, 1, 0, 1, 0, 1, 8, 9, 8, 9, 8, 9, 8, 9, 16, 17, 16, 17,16, 17,16, 17, 24, 25,24,25,24,25,24,25); + + prev_state_full = _mm256_add_epi8(prev_state_full, temp_add); + prev_state_full = _mm256_blendv_epi8(prev_state_full, _mm256_set1_epi64x(-1), comp_mask); + + for (int i = 0; i < 64; i += (256 / 8 / sizeof(uint16_t))) { + __m256i data = _mm256_load_si256((__m256i*)(&state->m_ctxInit[(ctxs->m_prev_state_offset >> 2)][i])); + data = _mm256_shuffle_epi8(data, prev_state_full); + _mm256_store_si256((__m256i*)(&state->m_ctxInit[(state_offset >> 2)][i]), data); + } + } } uint32_t level_offset = scan_pos & 15; __m128i max_abs = _mm_min_epi32(abs_level, _mm_set1_epi32(51));