diff --git a/src/strategies/avx2/depquant-avx2.c b/src/strategies/avx2/depquant-avx2.c index 1a00be56..cacee3fd 100644 --- a/src/strategies/avx2/depquant-avx2.c +++ b/src/strategies/avx2/depquant-avx2.c @@ -497,7 +497,7 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos, __m128i prev_state; __m128i prev_state_no_offset; __m128i abs_level = _mm_load_si128((const __m128i*)decisions->absLevel); - __m128i control = _mm_setr_epi8(0, 4, 8, 12, 0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1); + __m128i control = _mm_setr_epi8(0, 4, 8, 12, 0, 4, 8, 12, 0, 4, 8, 12, 0, 4, 8, 12); if (all_above_four) { prev_state = _mm_set1_epi32(ctxs->m_skip_state_offset); prev_state_no_offset = _mm_sub_epi32(_mm_load_si128((const __m128i*)decisions->prevId), _mm_set1_epi32(4)); @@ -575,34 +575,101 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos, { const uint32_t numSbb = width_in_sbb * height_in_sbb; common_context* cc = &ctxs->m_common_context; - size_t setCpSize = cc->m_nbInfo[scan_pos - 1].maxDist * sizeof(uint8_t); - int previous_state_array[4]; + size_t setCpSize = cc->m_nbInfo[scan_pos - 1].maxDist * sizeof(uint8_t); + uint8_t* sbbFlags = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset].sbbFlags; + uint8_t* levels = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset].levels + scan_pos * 4; + uint8_t* levels_in = cc->m_allSbbCtx[cc->m_prev_sbb_ctx_offset].levels + scan_pos * 4; + int previous_state_array[4]; _mm_storeu_si128((__m128i*)previous_state_array, prev_state); - for (int curr_state = 0; curr_state < 4; ++curr_state) { - uint8_t* sbbFlags = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset ].sbbFlags; - uint8_t* levels = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset].levels; - const int p_state = previous_state_array[curr_state]; - if (p_state != -1 && ctxs->m_allStates.m_refSbbCtxId[p_state] >= 0) { - const int prev_sbb = ctxs->m_allStates.m_refSbbCtxId[p_state]; - for (int i = 0; i < numSbb; ++i) { - sbbFlags[i * 4 + curr_state] = cc->m_allSbbCtx[cc->m_prev_sbb_ctx_offset].sbbFlags[i * 4 + prev_sbb]; - } - for (int i = 16; i < setCpSize; ++i) { - levels[scan_pos * 4 + i * 4 + curr_state] = cc->m_allSbbCtx[cc->m_prev_sbb_ctx_offset].levels[scan_pos * 4 + i * 4 + prev_sbb]; + + if (all_have_previous_state) { + __m128i temp_p_state = _mm_shuffle_epi8(prev_state, control); + __m128i ref_sbb_ctx_offset = + _mm_load_si128((__m128i*)ctxs->m_allStates.m_refSbbCtxId); + ref_sbb_ctx_offset = _mm_shuffle_epi8(ref_sbb_ctx_offset, temp_p_state); + if (numSbb <= 4) { + __m128i incremented_ref_sbb_ctx_offset = _mm_add_epi8( + ref_sbb_ctx_offset, + _mm_setr_epi8(0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12) + ); + __m128i blend_mask = _mm_cmpeq_epi8(ref_sbb_ctx_offset, _mm_set1_epi32(0xffffffff)); + __m128i sbb_flags = _mm_loadu_si128((__m128i*)cc->m_allSbbCtx[cc->m_prev_sbb_ctx_offset].sbbFlags); + sbb_flags = _mm_shuffle_epi8(sbb_flags, incremented_ref_sbb_ctx_offset); + sbb_flags = _mm_blendv_epi8(sbb_flags, _mm_set1_epi64x(0), blend_mask); + if (numSbb == 2) { + uint64_t temp = _mm_extract_epi64(sbb_flags, 0); + memcpy(sbbFlags, &temp, 8); + } else { + _mm_storeu_si128((__m128i*)sbbFlags, sbb_flags); } } else { - for (int i = 0; i < numSbb; ++i) { - sbbFlags[i * 4 + curr_state] = 0; - } - for (int i = 16; i < setCpSize; ++i) { - levels[scan_pos * 4 + i * 4 + curr_state] = 0; + __m256i extended_ref_state = _mm256_zextsi128_si256(ref_sbb_ctx_offset); + extended_ref_state = _mm256_permute4x64_epi64(extended_ref_state, 0); + __m256i inc_ref_state = _mm256_add_epi8( + extended_ref_state, + _mm256_setr_epi32(0, 0x04040404, 0x08080808, 0x0c0c0c0c,0, 0x04040404, 0x08080808, 0x0c0c0c0c) + ); + __m256i blend_mask = _mm256_cmpeq_epi8(extended_ref_state, _mm256_set1_epi32(0xffffffff)); + inc_ref_state = _mm256_blendv_epi8(inc_ref_state, _mm256_set1_epi32(0xffffffff), blend_mask); + for (int i = 0; i < numSbb * 4; i += 32) { + __m256i sbb_flags = _mm256_loadu_si256((__m256i*)(&cc->m_allSbbCtx[cc->m_prev_sbb_ctx_offset].sbbFlags[i])); + sbb_flags = _mm256_shuffle_epi8(sbb_flags, inc_ref_state); + _mm256_store_si256((__m256i*)&sbbFlags[i], sbb_flags); } } - sbbFlags[cg_pos * 4 + curr_state] = ctxs->m_allStates.m_numSigSbb[curr_state + state_offset]; - for (int i = 0; i < 16; ++i) { - levels[scan_pos * 4 + i * 4 + curr_state] = ctxs->m_allStates.m_absLevels[state_offset / 4][i * 4 + curr_state]; + int levels_start = 16; + const uint64_t limit = setCpSize & ~(8 - 1); + if (levels_start < limit) { + __m256i extended_ref_state = _mm256_zextsi128_si256(ref_sbb_ctx_offset); + extended_ref_state = _mm256_permute4x64_epi64(extended_ref_state, 0); + __m256i inc_ref_state = _mm256_add_epi8( + extended_ref_state, + _mm256_setr_epi32(0, 0x04040404, 0x08080808, 0x0c0c0c0c,0, 0x04040404, 0x08080808, 0x0c0c0c0c) + ); + __m256i blend_mask = _mm256_cmpeq_epi8(extended_ref_state, _mm256_set1_epi32(0xffffffff)); + inc_ref_state = _mm256_blendv_epi8(inc_ref_state, _mm256_set1_epi32(0xffffffff), blend_mask); + for (; levels_start < limit; levels_start += 8) { + __m256i levels_v = _mm256_loadu_si256((__m256i*)(&levels_in[levels_start * 4])); + levels_v = _mm256_shuffle_epi8(levels_v, inc_ref_state); + _mm256_store_si256((__m256i*)&levels[levels_start * 4], levels_v); + } + } + uint8_t ref_sbb[4]; + int temp_sbb_ref = _mm_extract_epi32(ref_sbb_ctx_offset, 0); + memcpy(ref_sbb, &temp_sbb_ref, 4); + for (;levels_start < setCpSize; ++levels_start) { + uint8_t new_values[4]; + new_values[0] = ref_sbb[0] != 0xff ? levels_in[levels_start * 4 + ref_sbb[0]] : 0; + new_values[1] = ref_sbb[1] != 0xff ? levels_in[levels_start * 4 + ref_sbb[1]] : 0; + new_values[2] = ref_sbb[2] != 0xff ? levels_in[levels_start * 4 + ref_sbb[2]] : 0; + new_values[3] = ref_sbb[3] != 0xff ? levels_in[levels_start * 4 + ref_sbb[3]] : 0; + memcpy(&levels[levels_start * 4], new_values, 4); + } + + } + else { + for (int curr_state = 0; curr_state < 4; ++curr_state) { + const int p_state = previous_state_array[curr_state]; + if (p_state != -1 && ctxs->m_allStates.m_refSbbCtxId[p_state] >= 0) { + const int prev_sbb = ctxs->m_allStates.m_refSbbCtxId[p_state]; + for (int i = 0; i < numSbb; ++i) { + sbbFlags[i * 4 + curr_state] = cc->m_allSbbCtx[cc->m_prev_sbb_ctx_offset].sbbFlags[i * 4 + prev_sbb]; + } + for (int i = 16; i < setCpSize; ++i) { + levels[i * 4 + curr_state] = levels_in[i * 4 + prev_sbb]; + } + } else { + for (int i = 0; i < numSbb; ++i) { + sbbFlags[i * 4 + curr_state] = 0; + } + for (int i = 16; i < setCpSize; ++i) { + levels[ i * 4 + curr_state] = 0; + } + } } } + memcpy(levels, ctxs->m_allStates.m_absLevels[state_offset / 4], 64); + memcpy(&sbbFlags[cg_pos * 4], &ctxs->m_allStates.m_numSigSbb[state_offset], 4); __m128i sbb_right = next_sbb_right ? _mm_cvtepi8_epi32(_mm_loadu_si128((__m128i*)&cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset].sbbFlags[next_sbb_right * 4])) : @@ -640,8 +707,6 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos, const NbInfoOut* nbOut = cc->m_nbInfo + scanBeg; const uint8_t* absLevels = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset].levels + scanBeg * 4; - __m128i levels_offsets = _mm_set_epi32(cc->num_coeff * 3, cc->num_coeff * 2, cc->num_coeff * 1, 0); - __m128i first_byte = _mm_set1_epi32(0xff); __m128i ones = _mm_set1_epi32(1); __m128i fours = _mm_set1_epi32(4); __m256i all[4]; @@ -738,10 +803,8 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos, _mm256_storeu_si256((__m256i*)(&state->m_ctxInit[state_offset >> 2][16]), all[1]); _mm256_storeu_si256((__m256i*)(&state->m_ctxInit[state_offset >> 2][32]), all[2]); _mm256_storeu_si256((__m256i*)(&state->m_ctxInit[state_offset >> 2][48]), all[3]); - - for (int i = 0; i < 4; ++i) { - memset(state->m_absLevels[state_offset >> 2], 0, 16 * 4); - } + + memset(state->m_absLevels[state_offset >> 2], 0, 16 * 4); } __m128i sum_num = _mm_and_si128(last, _mm_set1_epi32(7)); @@ -973,10 +1036,8 @@ static INLINE void update_states_avx2( state->all_lt_four = rem_reg_all_lt4; if (rem_reg_all_gte_4) { - const __m128i first_byte = _mm_set1_epi32(0xff); const __m128i ones = _mm_set1_epi32(1); const uint32_t tinit_offset = MIN(level_offset - 1u, 15u); - const __m128i levels_start_offsets = _mm_set_epi32(16 * 3, 16 * 2, 16 * 1, 16 * 0); __m128i tinit = _mm_loadu_si128((__m128i*)(&state->m_ctxInit[state_offset >> 2][tinit_offset * 4])); tinit = _mm_cvtepi16_epi32(tinit); __m128i sum_abs1 = _mm_and_si128(_mm_srli_epi32(tinit, 3), _mm_set1_epi32(31)); @@ -1119,11 +1180,8 @@ static INLINE void update_states_avx2( } else if (rem_reg_all_lt4) { - const __m128i first_byte = _mm_set1_epi32(0xff); uint8_t* levels = (uint8_t*)state->m_absLevels[state_offset >> 2]; - const __m128i last_byte = _mm_set1_epi32(0xff); const uint32_t tinit_offset = MIN(level_offset - 1u, 15u); - const __m128i levels_start_offsets = _mm_set_epi32(16 * 3, 16 * 2, 16 * 1, 16 * 0); __m128i tinit = _mm_loadu_si128((__m128i*)(&state->m_ctxInit[state_offset >> 2][tinit_offset * 4])); tinit = _mm_cvtepi16_epi32(tinit); __m128i sum_abs = _mm_srli_epi32(tinit, 8);