[avx2] Add comments

This commit is contained in:
Joose Sainio 2023-05-29 10:36:18 +03:00
parent f2fb641acb
commit 254826d396

View file

@ -81,12 +81,18 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
if (state->all_gte_four) { if (state->all_gte_four) {
// pqDataA
// If both levels are smaller than 4, or both are gte 4, AVX2 can be used
if (pqDataA->absLevel[0] < 4 && pqDataA->absLevel[3] < 4) { if (pqDataA->absLevel[0] < 4 && pqDataA->absLevel[3] < 4) {
// The coeffFracBits arrays are 6 elements long, so we need to offset the indices, and gather is the only efficient way to load the data
__m128i offsets = _mm_set_epi32(18 + pqDataA->absLevel[3], 12 + pqDataA->absLevel[3], 6 + pqDataA->absLevel[0], 0 + pqDataA->absLevel[0]); __m128i offsets = _mm_set_epi32(18 + pqDataA->absLevel[3], 12 + pqDataA->absLevel[3], 6 + pqDataA->absLevel[0], 0 + pqDataA->absLevel[0]);
__m128i coeff_frac_bits = _mm_i32gather_epi32(&state->m_coeffFracBits[start][0], offsets, 4); __m128i coeff_frac_bits = _mm_i32gather_epi32(&state->m_coeffFracBits[start][0], offsets, 4);
// RD costs are 64 bit, so we need to extend the 32 bit values
__m256i ext_frac_bits = _mm256_cvtepi32_epi64(coeff_frac_bits); __m256i ext_frac_bits = _mm256_cvtepi32_epi64(coeff_frac_bits);
rd_cost_a = _mm256_add_epi64(rd_cost_a, ext_frac_bits); rd_cost_a = _mm256_add_epi64(rd_cost_a, ext_frac_bits);
} else if (pqDataA->absLevel[0] >= 4 && pqDataA->absLevel[3] >= 4) { }
else if (pqDataA->absLevel[0] >= 4 && pqDataA->absLevel[3] >= 4) {
__m128i value = _mm_set_epi32((pqDataA->absLevel[3] - 4) >> 1, (pqDataA->absLevel[3] - 4) >> 1, (pqDataA->absLevel[0] - 4) >> 1, (pqDataA->absLevel[0] - 4) >> 1); __m128i value = _mm_set_epi32((pqDataA->absLevel[3] - 4) >> 1, (pqDataA->absLevel[3] - 4) >> 1, (pqDataA->absLevel[0] - 4) >> 1, (pqDataA->absLevel[0] - 4) >> 1);
__m128i offsets = _mm_set_epi32(18 + pqDataA->absLevel[3], 12 + pqDataA->absLevel[3], 6 + pqDataA->absLevel[0], 0 + pqDataA->absLevel[0]); __m128i offsets = _mm_set_epi32(18 + pqDataA->absLevel[3], 12 + pqDataA->absLevel[3], 6 + pqDataA->absLevel[0], 0 + pqDataA->absLevel[0]);
@ -96,6 +102,8 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
__m128i max_rice = _mm_set1_epi32(31); __m128i max_rice = _mm_set1_epi32(31);
value = _mm_min_epi32(value, max_rice); value = _mm_min_epi32(value, max_rice);
// In the original implementation the goRiceTab is selected beforehand, but since we need to load from
// potentially four different locations, we need to calculate the offsets and use gather
__m128i go_rice_tab = _mm_cvtepi8_epi32(_mm_loadu_si32(&state->m_goRicePar[start])); __m128i go_rice_tab = _mm_cvtepi8_epi32(_mm_loadu_si32(&state->m_goRicePar[start]));
go_rice_tab = _mm_slli_epi32(go_rice_tab, 5); go_rice_tab = _mm_slli_epi32(go_rice_tab, 5);
value = _mm_add_epi32(value, go_rice_tab); value = _mm_add_epi32(value, go_rice_tab);
@ -104,7 +112,8 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
rd_cost_a = _mm256_add_epi64(rd_cost_a, _mm256_cvtepi32_epi64(temp)); rd_cost_a = _mm256_add_epi64(rd_cost_a, _mm256_cvtepi32_epi64(temp));
} else { } else {
const int pqAs[4] = {0, 0, 3, 3}; const int pqAs[4] = {0, 0, 3, 3};
ALIGNED(32) int64_t rd_costs[4] = {0, 0, 0, 0}; ALIGNED(32) int64_t rd_costs[4] = {0, 0, 0, 0};
// AVX2 cannot be used so we have to loop the values normally
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
const int state_offset = start + i; const int state_offset = start + i;
const int pqA = pqAs[i]; const int pqA = pqAs[i];
@ -119,6 +128,7 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
rd_cost_a = _mm256_add_epi64(rd_cost_a, _mm256_loadu_si256((__m256i const *)&rd_costs[0])); rd_cost_a = _mm256_add_epi64(rd_cost_a, _mm256_loadu_si256((__m256i const *)&rd_costs[0]));
} }
// pqDataB, same stuff as for pqDataA
if (pqDataA->absLevel[1] < 4 && pqDataA->absLevel[2] < 4) { if (pqDataA->absLevel[1] < 4 && pqDataA->absLevel[2] < 4) {
__m128i offsets = _mm_set_epi32(18 + pqDataA->absLevel[1], 12 + pqDataA->absLevel[1], 6 + pqDataA->absLevel[2], 0 + pqDataA->absLevel[2]); __m128i offsets = _mm_set_epi32(18 + pqDataA->absLevel[1], 12 + pqDataA->absLevel[1], 6 + pqDataA->absLevel[2], 0 + pqDataA->absLevel[2]);
__m128i coeff_frac_bits = _mm_i32gather_epi32(state->m_coeffFracBits[start], offsets, 4); __m128i coeff_frac_bits = _mm_i32gather_epi32(state->m_coeffFracBits[start], offsets, 4);
@ -159,6 +169,10 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
} }
if (spt == SCAN_ISCSBB) { if (spt == SCAN_ISCSBB) {
// This loads the values so that they are laid out as
// |State 0 Flag 0|State 0 Flag 1|State 1 Flag 0|State 1 Flag 1|State 2 Flag 0|State 2 Flag 1|State 3 Flag 0|State 3 Flag 1|
// By setting the flag 1 bits to zero we get the flag 0 values as 64 bit integers (the `even` variable), which can be summed to the rd_cost
// Flag 1 values can be shifted right by 32 bits and again we have 64 bit integers holding the values (the `odd` variable), which can be summed to the rd_cost
__m256i original = _mm256_loadu_si256((__m256i const*)state->m_sigFracBits[start]); __m256i original = _mm256_loadu_si256((__m256i const*)state->m_sigFracBits[start]);
__m256i even = _mm256_and_si256(original, _mm256_set1_epi64x(0xffffffff)); __m256i even = _mm256_and_si256(original, _mm256_set1_epi64x(0xffffffff));
__m256i odd = _mm256_srli_epi64(original, 32); __m256i odd = _mm256_srli_epi64(original, 32);
@ -168,6 +182,7 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
} else if (spt == SCAN_SOCSBB) { } else if (spt == SCAN_SOCSBB) {
__m256i original = _mm256_loadu_si256((__m256i const*)state->m_sigFracBits[start]); __m256i original = _mm256_loadu_si256((__m256i const*)state->m_sigFracBits[start]);
// Same here
__m256i m_sigFracBits_0 = _mm256_and_si256(original, _mm256_set1_epi64x(0xffffffff)); __m256i m_sigFracBits_0 = _mm256_and_si256(original, _mm256_set1_epi64x(0xffffffff));
__m256i m_sigFracBits_1 = _mm256_srli_epi64(original, 32); __m256i m_sigFracBits_1 = _mm256_srli_epi64(original, 32);
@ -185,6 +200,7 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
else { else {
int num_sig_sbb; int num_sig_sbb;
memcpy(&num_sig_sbb, &state->m_numSigSbb[start], 4); memcpy(&num_sig_sbb, &state->m_numSigSbb[start], 4);
// numSigSbb only has values 1 or zero, so if all 4 values are 1 the complete value is 0x01010101
if (num_sig_sbb == 0x01010101) { if (num_sig_sbb == 0x01010101) {
__m256i original = _mm256_loadu_si256((__m256i const*)state->m_sigFracBits[start]); __m256i original = _mm256_loadu_si256((__m256i const*)state->m_sigFracBits[start]);
__m256i even = _mm256_and_si256(original, _mm256_set1_epi64x(0xffffffff)); __m256i even = _mm256_and_si256(original, _mm256_set1_epi64x(0xffffffff));
@ -224,25 +240,30 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
// RD cost A // RD cost A
{ {
__m128i pq_abs_a = _mm_set_epi32(pqDataA->absLevel[3], pqDataA->absLevel[3], pqDataA->absLevel[0], pqDataA->absLevel[0]); __m128i pq_abs_a = _mm_set_epi32(pqDataA->absLevel[3], pqDataA->absLevel[3], pqDataA->absLevel[0], pqDataA->absLevel[0]);
// Calculate mask for pqDataA->absLevel <= state->m_goRiceZero
// The mask is reverse of the one that is used in the scalar code so the values are in other order in blendv
__m128i cmp = _mm_cmpgt_epi32(pq_abs_a, go_rice_zero); __m128i cmp = _mm_cmpgt_epi32(pq_abs_a, go_rice_zero);
// pqDataA->absLevel < RICEMAX ? pqDataA->absLevel : RICEMAX - 1
__m128i go_rice_smaller = _mm_min_epi32(pq_abs_a, max_rice); __m128i go_rice_smaller = _mm_min_epi32(pq_abs_a, max_rice);
// pqDataA->absLevel - 1
__m128i other = _mm_sub_epi32(pq_abs_a, _mm_set1_epi32(1)); __m128i other = _mm_sub_epi32(pq_abs_a, _mm_set1_epi32(1));
__m128i selected = _mm_blendv_epi8(other, go_rice_smaller, cmp); __m128i selected = _mm_blendv_epi8(other, go_rice_smaller, cmp);
// Again calculate the offset for the different go_rice_tabs
__m128i go_rice_offset = _mm_cvtepi8_epi32(_mm_loadu_si128((__m128i*)&state->m_goRicePar[start])); __m128i go_rice_offset = _mm_cvtepi8_epi32(_mm_loadu_si128((__m128i*)&state->m_goRicePar[start]));
go_rice_offset = _mm_slli_epi32(go_rice_offset, 5); go_rice_offset = _mm_slli_epi32(go_rice_offset, 5);
__m128i offsets = _mm_add_epi32(selected, go_rice_offset); __m128i offsets = _mm_add_epi32(selected, go_rice_offset);
__m128i go_rice_tab = _mm_i32gather_epi32(&g_goRiceBits[0][0], offsets, 4); __m128i go_rice_tab = _mm_i32gather_epi32(&g_goRiceBits[0][0], offsets, 4);
//(1 << SCALE_BITS) + goRiceTab[selected]
__m128i temp = _mm_add_epi32(go_rice_tab, scale_bits); __m128i temp = _mm_add_epi32(go_rice_tab, scale_bits);
rd_cost_a = _mm256_add_epi64(rd_cost_a, _mm256_cvtepi32_epi64(temp)); rd_cost_a = _mm256_add_epi64(rd_cost_a, _mm256_cvtepi32_epi64(temp));
} }
// RD cost b // RD cost b, same as RD cost A
{ {
__m128i pq_abs_b = _mm_set_epi32(pqDataA->absLevel[1], pqDataA->absLevel[1], pqDataA->absLevel[2], pqDataA->absLevel[2]); __m128i pq_abs_b = _mm_set_epi32(pqDataA->absLevel[1], pqDataA->absLevel[1], pqDataA->absLevel[2], pqDataA->absLevel[2]);
__m128i cmp = _mm_cmpgt_epi32(pq_abs_b, go_rice_zero); __m128i cmp = _mm_cmpgt_epi32(pq_abs_b, go_rice_zero);
@ -265,6 +286,7 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
} }
// RD cost Z // RD cost Z
{ {
// This time the go_rice_tab is offset with only the go_rice_zero
__m128i go_rice_offset = _mm_cvtepi8_epi32(_mm_loadu_si128((__m128i*)&state->m_goRicePar[start])); __m128i go_rice_offset = _mm_cvtepi8_epi32(_mm_loadu_si128((__m128i*)&state->m_goRicePar[start]));
go_rice_offset = _mm_slli_epi32(go_rice_offset, 5); go_rice_offset = _mm_slli_epi32(go_rice_offset, 5);
@ -325,6 +347,7 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
rd_cost_b = _mm256_loadu_si256((__m256i*)temp_rd_cost_b); rd_cost_b = _mm256_loadu_si256((__m256i*)temp_rd_cost_b);
rd_cost_z = _mm256_loadu_si256((__m256i*)temp_rd_cost_z); rd_cost_z = _mm256_loadu_si256((__m256i*)temp_rd_cost_z);
} }
// Reorder the costs so that the cost of state 0 is in the first element, state 1 in the second, etc.
rd_cost_a = _mm256_permute4x64_epi64(rd_cost_a, 216); rd_cost_a = _mm256_permute4x64_epi64(rd_cost_a, 216);
rd_cost_b = _mm256_permute4x64_epi64(rd_cost_b, 141); rd_cost_b = _mm256_permute4x64_epi64(rd_cost_b, 141);
rd_cost_z = _mm256_permute4x64_epi64(rd_cost_z, 216); rd_cost_z = _mm256_permute4x64_epi64(rd_cost_z, 216);
@ -334,8 +357,9 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
__m256i decision_prev_state = _mm256_load_si256((__m256i*)decisions->prevId); __m256i decision_prev_state = _mm256_load_si256((__m256i*)decisions->prevId);
__m256i decision_data = _mm256_permute2x128_si256(decision_abs_coeff, decision_prev_state, 0x20); __m256i decision_data = _mm256_permute2x128_si256(decision_abs_coeff, decision_prev_state, 0x20);
__m256i mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); __m256i mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
decision_data = _mm256_permutevar8x32_epi32(decision_data, mask);
// Store data for all of the cost so that the lower 32 bits have coefficient magnitude and upper have the previous state
decision_data = _mm256_permutevar8x32_epi32(decision_data, mask);
__m256i a_data = _mm256_set_epi32(3, pqDataA->absLevel[3], 1, pqDataA->absLevel[0], 2, pqDataA->absLevel[3], 0, pqDataA->absLevel[0]); __m256i a_data = _mm256_set_epi32(3, pqDataA->absLevel[3], 1, pqDataA->absLevel[0], 2, pqDataA->absLevel[3], 0, pqDataA->absLevel[0]);
__m256i b_data = _mm256_set_epi32(2, pqDataA->absLevel[1], 0, pqDataA->absLevel[2], 3, pqDataA->absLevel[1], 1, pqDataA->absLevel[2]); __m256i b_data = _mm256_set_epi32(2, pqDataA->absLevel[1], 0, pqDataA->absLevel[2], 3, pqDataA->absLevel[1], 1, pqDataA->absLevel[2]);
__m256i z_data = _mm256_set_epi32(3, 0, 1, 0, 2, 0, 0, 0); __m256i z_data = _mm256_set_epi32(3, 0, 1, 0, 2, 0, 0, 0);
@ -514,6 +538,7 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
prev_state_no_offset, prev_state_no_offset,
_mm_set1_epi32(ctxs->m_prev_state_offset) _mm_set1_epi32(ctxs->m_prev_state_offset)
); );
// Set the high bytes to 0xff so that the shuffle will set them to zero and it won't cause problems with the min_epi32
__m128i prev_state_with_ff_high_bytes = _mm_or_si128(prev_state, _mm_set1_epi32(0xffffff00)); __m128i prev_state_with_ff_high_bytes = _mm_or_si128(prev_state, _mm_set1_epi32(0xffffff00));
__m128i num_sig_sbb = _mm_load_si128((const __m128i*)state->m_numSigSbb); __m128i num_sig_sbb = _mm_load_si128((const __m128i*)state->m_numSigSbb);
num_sig_sbb = _mm_shuffle_epi8(num_sig_sbb, prev_state_with_ff_high_bytes); num_sig_sbb = _mm_shuffle_epi8(num_sig_sbb, prev_state_with_ff_high_bytes);
@ -526,9 +551,12 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
int num_sig_sbb_s = _mm_extract_epi32(num_sig_sbb, 0); int num_sig_sbb_s = _mm_extract_epi32(num_sig_sbb, 0);
memcpy(&state->m_numSigSbb[state_offset], &num_sig_sbb_s, 4); memcpy(&state->m_numSigSbb[state_offset], &num_sig_sbb_s, 4);
// Set this so that the temp_prev_state has the previous state set into the first 4 bytes and duplicated to the second 4 bytes
__m128i temp_prev_state = _mm_shuffle_epi8(prev_state_no_offset, control); __m128i temp_prev_state = _mm_shuffle_epi8(prev_state_no_offset, control);
__m256i prev_state_256 = _mm256_castsi128_si256(temp_prev_state); __m256i prev_state_256 = _mm256_castsi128_si256(temp_prev_state);
// Duplicate the state all over the vector so that all 32 bytes hold the previous states
prev_state_256 = _mm256_permute4x64_epi64(prev_state_256, 0); prev_state_256 = _mm256_permute4x64_epi64(prev_state_256, 0);
// Increment the second set by four, third by eight and fourth by twelve and repeat for the second lane
__m256i temp_add = _mm256_setr_epi32(0, 0x04040404, 0x08080808, 0x0c0c0c0c, 0, 0x04040404, 0x08080808, 0x0c0c0c0c); __m256i temp_add = _mm256_setr_epi32(0, 0x04040404, 0x08080808, 0x0c0c0c0c, 0, 0x04040404, 0x08080808, 0x0c0c0c0c);
prev_state_256 = _mm256_add_epi8(prev_state_256, temp_add); prev_state_256 = _mm256_add_epi8(prev_state_256, temp_add);
for (int i = 0; i < 64; i += (256 / (8 * sizeof(uint8_t)))) { for (int i = 0; i < 64; i += (256 / (8 * sizeof(uint8_t)))) {
@ -537,6 +565,8 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
_mm256_store_si256((__m256i*)&state->m_absLevels[ctxs->m_curr_state_offset >> 2][i], data); _mm256_store_si256((__m256i*)&state->m_absLevels[ctxs->m_curr_state_offset >> 2][i], data);
} }
} else { } else {
// TODO: it would be possible to do the absLevels update with avx2 even here just would need to set the shuffle mask to
// 0xff for the states that don't have previous state or the previous state is a skip state
int prev_state_s[4] = {-1, -1, -1, -1}; int prev_state_s[4] = {-1, -1, -1, -1};
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
const int decision_id = i; const int decision_id = i;
@ -584,14 +614,18 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
if (all_have_previous_state) { if (all_have_previous_state) {
__m128i temp_p_state = _mm_shuffle_epi8(prev_state, control); __m128i temp_p_state = _mm_shuffle_epi8(prev_state, control);
__m128i ref_sbb_ctx_offset = // Similarly to how the abs level was done earlier set the previous state duplicated across the lane
_mm_load_si128((__m128i*)ctxs->m_allStates.m_refSbbCtxId); __m128i ref_sbb_ctx_offset = _mm_load_si128((__m128i*)ctxs->m_allStates.m_refSbbCtxId);
ref_sbb_ctx_offset = _mm_shuffle_epi8(ref_sbb_ctx_offset, temp_p_state); ref_sbb_ctx_offset = _mm_shuffle_epi8(ref_sbb_ctx_offset, temp_p_state);
// numSbb is two or four, in case it is one this function is never called
if (numSbb <= 4) { if (numSbb <= 4) {
__m128i incremented_ref_sbb_ctx_offset = _mm_add_epi8( __m128i incremented_ref_sbb_ctx_offset = _mm_add_epi8(
ref_sbb_ctx_offset, ref_sbb_ctx_offset,
_mm_setr_epi8(0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12) _mm_setr_epi8(0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12)
); );
// In case the ref_sbb_ctx is minus one the values need to be set to zero, which is achieved by
// first finding which states have the minus one and then the blend is used after the load to
// set the corresponding values to zero
__m128i blend_mask = _mm_cmpeq_epi8(ref_sbb_ctx_offset, _mm_set1_epi32(0xffffffff)); __m128i blend_mask = _mm_cmpeq_epi8(ref_sbb_ctx_offset, _mm_set1_epi32(0xffffffff));
__m128i sbb_flags = _mm_loadu_si128((__m128i*)cc->m_allSbbCtx[cc->m_prev_sbb_ctx_offset].sbbFlags); __m128i sbb_flags = _mm_loadu_si128((__m128i*)cc->m_allSbbCtx[cc->m_prev_sbb_ctx_offset].sbbFlags);
sbb_flags = _mm_shuffle_epi8(sbb_flags, incremented_ref_sbb_ctx_offset); sbb_flags = _mm_shuffle_epi8(sbb_flags, incremented_ref_sbb_ctx_offset);
@ -609,6 +643,10 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
extended_ref_state, extended_ref_state,
_mm256_setr_epi32(0, 0x04040404, 0x08080808, 0x0c0c0c0c,0, 0x04040404, 0x08080808, 0x0c0c0c0c) _mm256_setr_epi32(0, 0x04040404, 0x08080808, 0x0c0c0c0c,0, 0x04040404, 0x08080808, 0x0c0c0c0c)
); );
// Unlike the case for two or four sbb, the blendv is used to set the shuffle mask to -1 so that
// the shuffle will set the values to zero. It's better to do it this way here so that the blendv is
// not called in the loop; the other case is done the other way because I implemented it first
// and only realized afterwards that this order is better
__m256i blend_mask = _mm256_cmpeq_epi8(extended_ref_state, _mm256_set1_epi32(0xffffffff)); __m256i blend_mask = _mm256_cmpeq_epi8(extended_ref_state, _mm256_set1_epi32(0xffffffff));
inc_ref_state = _mm256_blendv_epi8(inc_ref_state, _mm256_set1_epi32(0xffffffff), blend_mask); inc_ref_state = _mm256_blendv_epi8(inc_ref_state, _mm256_set1_epi32(0xffffffff), blend_mask);
for (int i = 0; i < numSbb * 4; i += 32) { for (int i = 0; i < numSbb * 4; i += 32) {
@ -617,9 +655,12 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
_mm256_store_si256((__m256i*)&sbbFlags[i], sbb_flags); _mm256_store_si256((__m256i*)&sbbFlags[i], sbb_flags);
} }
} }
// The first 16 variables will be loaded from the previous state so this can be started from 16
int levels_start = 16; int levels_start = 16;
// Do avx2 optimized version for the amount that is divisible by 8 (four states of 8 1-byte values)
const uint64_t limit = setCpSize & ~(8 - 1); const uint64_t limit = setCpSize & ~(8 - 1);
if (levels_start < limit) { if (levels_start < limit) {
// Overall this is the same to the numSbb > 4
__m256i extended_ref_state = _mm256_zextsi128_si256(ref_sbb_ctx_offset); __m256i extended_ref_state = _mm256_zextsi128_si256(ref_sbb_ctx_offset);
extended_ref_state = _mm256_permute4x64_epi64(extended_ref_state, 0); extended_ref_state = _mm256_permute4x64_epi64(extended_ref_state, 0);
__m256i inc_ref_state = _mm256_add_epi8( __m256i inc_ref_state = _mm256_add_epi8(
@ -637,6 +678,7 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
uint8_t ref_sbb[4]; uint8_t ref_sbb[4];
int temp_sbb_ref = _mm_extract_epi32(ref_sbb_ctx_offset, 0); int temp_sbb_ref = _mm_extract_epi32(ref_sbb_ctx_offset, 0);
memcpy(ref_sbb, &temp_sbb_ref, 4); memcpy(ref_sbb, &temp_sbb_ref, 4);
// Do the excess that is not divisible by 8
for (;levels_start < setCpSize; ++levels_start) { for (;levels_start < setCpSize; ++levels_start) {
uint8_t new_values[4]; uint8_t new_values[4];
new_values[0] = ref_sbb[0] != 0xff ? levels_in[levels_start * 4 + ref_sbb[0]] : 0; new_values[0] = ref_sbb[0] != 0xff ? levels_in[levels_start * 4 + ref_sbb[0]] : 0;
@ -648,6 +690,8 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
} }
else { else {
// TODO: This could also be done using avx2; it just needs to check both whether the previous state
// is minus one and whether the ref_sbb_ctx_id is minus one.
for (int curr_state = 0; curr_state < 4; ++curr_state) { for (int curr_state = 0; curr_state < 4; ++curr_state) {
const int p_state = previous_state_array[curr_state]; const int p_state = previous_state_array[curr_state];
if (p_state != -1 && ctxs->m_allStates.m_refSbbCtxId[p_state] >= 0) { if (p_state != -1 && ctxs->m_allStates.m_refSbbCtxId[p_state] >= 0) {
@ -681,6 +725,8 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
__m128i sig_sbb = _mm_or_si128(sbb_right, sbb_below); __m128i sig_sbb = _mm_or_si128(sbb_right, sbb_below);
sig_sbb = _mm_min_epi32(sig_sbb, _mm_set1_epi32(1)); sig_sbb = _mm_min_epi32(sig_sbb, _mm_set1_epi32(1));
// Gather is not necessary here, but it would require at least five operations to do the same thing,
// so the performance gain in my opinion is not worth the readability loss
__m256i sbb_frac_bits = _mm256_i32gather_epi64((int64_t *)cc->m_sbbFlagBits[0], sig_sbb, 8); __m256i sbb_frac_bits = _mm256_i32gather_epi64((int64_t *)cc->m_sbbFlagBits[0], sig_sbb, 8);
_mm256_store_si256((__m256i*)state->m_sbbFracBits[state_offset], sbb_frac_bits); _mm256_store_si256((__m256i*)state->m_sbbFracBits[state_offset], sbb_frac_bits);
@ -806,6 +852,7 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
memset(state->m_absLevels[state_offset >> 2], 0, 16 * 4); memset(state->m_absLevels[state_offset >> 2], 0, 16 * 4);
} }
// End update common context
__m128i sum_num = _mm_and_si128(last, _mm_set1_epi32(7)); __m128i sum_num = _mm_and_si128(last, _mm_set1_epi32(7));
__m128i sum_abs1 = _mm_and_si128( __m128i sum_abs1 = _mm_and_si128(
@ -829,6 +876,7 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
__m128i min_gt1 = _mm_min_epi32(sum_gt1, _mm_set1_epi32(4)); __m128i min_gt1 = _mm_min_epi32(sum_gt1, _mm_set1_epi32(4));
uint32_t sum_gt1_s[4]; uint32_t sum_gt1_s[4];
_mm_storeu_si128((__m128i*)sum_gt1_s, min_gt1); _mm_storeu_si128((__m128i*)sum_gt1_s, min_gt1);
// These are 192 bits so no benefit from using avx2
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
memcpy(state->m_coeffFracBits[state_offset + i], state->m_gtxFracBitsArray[sum_gt1_s[i] + gtxCtxOffsetNext], sizeof(state->m_coeffFracBits[0])); memcpy(state->m_coeffFracBits[state_offset + i], state->m_gtxFracBitsArray[sum_gt1_s[i] + gtxCtxOffsetNext], sizeof(state->m_coeffFracBits[0]));
} }
@ -887,7 +935,9 @@ static INLINE void update_states_avx2(
__m128i prev_offset = _mm_set1_epi32(ctxs->m_prev_state_offset); __m128i prev_offset = _mm_set1_epi32(ctxs->m_prev_state_offset);
__m128i prv_states = _mm_add_epi32(prv_states_o, prev_offset); __m128i prv_states = _mm_add_epi32(prv_states_o, prev_offset);
__m128i shuffled_prev_states = _mm_shuffle_epi8(prv_states, control); __m128i shuffled_prev_states = _mm_shuffle_epi8(prv_states, control);
// sig_sbb values matter only whether they are one or zero so make sure that they stay at one or zero
// which allows some optimizations when handling the values in update_state_eos_avx2
__m128i sig_sbb = _mm_load_si128((__m128i const*)state->m_numSigSbb); __m128i sig_sbb = _mm_load_si128((__m128i const*)state->m_numSigSbb);
sig_sbb = _mm_shuffle_epi8(sig_sbb, shuffled_prev_states); sig_sbb = _mm_shuffle_epi8(sig_sbb, shuffled_prev_states);
__m128i has_coeff = _mm_min_epi32(abs_level, _mm_set1_epi32(1)); __m128i has_coeff = _mm_min_epi32(abs_level, _mm_set1_epi32(1));
@ -895,7 +945,8 @@ static INLINE void update_states_avx2(
sig_sbb = _mm_or_si128(sig_sbb, has_coeff); sig_sbb = _mm_or_si128(sig_sbb, has_coeff);
int sig_sbb_i = _mm_extract_epi32(sig_sbb, 0); int sig_sbb_i = _mm_extract_epi32(sig_sbb, 0);
memcpy(&state->m_numSigSbb[state_offset], &sig_sbb_i, 4); memcpy(&state->m_numSigSbb[state_offset], &sig_sbb_i, 4);
// The following two are just shuffled and then the 4 bytes that store the values are extracted
__m128i ref_sbb_ctx_idx = _mm_load_si128((__m128i const*)state->m_refSbbCtxId); __m128i ref_sbb_ctx_idx = _mm_load_si128((__m128i const*)state->m_refSbbCtxId);
ref_sbb_ctx_idx = _mm_shuffle_epi8(ref_sbb_ctx_idx, shuffled_prev_states); ref_sbb_ctx_idx = _mm_shuffle_epi8(ref_sbb_ctx_idx, shuffled_prev_states);
int ref_sbb_ctx = _mm_extract_epi32(ref_sbb_ctx_idx, 0); int ref_sbb_ctx = _mm_extract_epi32(ref_sbb_ctx_idx, 0);
@ -906,23 +957,30 @@ static INLINE void update_states_avx2(
int go_rice_par_i = _mm_extract_epi32(go_rice_par, 0); int go_rice_par_i = _mm_extract_epi32(go_rice_par, 0);
memcpy(&state->m_goRicePar[state_offset], &go_rice_par_i, 4); memcpy(&state->m_goRicePar[state_offset], &go_rice_par_i, 4);
// Again gather is not necessary but it is easier to read and shouldn't have too large of a performance hit
// Should be true for all gathers here
__m256i sbb_frac_bits = _mm256_i32gather_epi64((const int64_t *)state->m_sbbFracBits[0], prv_states, 8); __m256i sbb_frac_bits = _mm256_i32gather_epi64((const int64_t *)state->m_sbbFracBits[0], prv_states, 8);
_mm256_store_si256((__m256i*)&state->m_sbbFracBits[state_offset][0], sbb_frac_bits); _mm256_store_si256((__m256i*)&state->m_sbbFracBits[state_offset][0], sbb_frac_bits);
// Next three lines: state->m_remRegBins = prvState->m_remRegBins - 1;
__m128i rem_reg_bins = _mm_i32gather_epi32(state->m_remRegBins, prv_states, 4); __m128i rem_reg_bins = _mm_i32gather_epi32(state->m_remRegBins, prv_states, 4);
__m128i ones = _mm_set1_epi32(1); __m128i ones = _mm_set1_epi32(1);
rem_reg_bins = _mm_sub_epi32(rem_reg_bins, ones); rem_reg_bins = _mm_sub_epi32(rem_reg_bins, ones);
__m128i reg_bins_sub = _mm_set1_epi32(0); __m128i reg_bins_sub = _mm_set1_epi32(0);
// Next two lines: (decision->absLevel < 2 ? (unsigned)decision->absLevel : 3)
__m128i abs_level_smaller_than_two = _mm_cmplt_epi32(abs_level, _mm_set1_epi32(2)); __m128i abs_level_smaller_than_two = _mm_cmplt_epi32(abs_level, _mm_set1_epi32(2));
__m128i secondary = _mm_blendv_epi8(_mm_set1_epi32(3), abs_level, abs_level_smaller_than_two); __m128i secondary = _mm_blendv_epi8(_mm_set1_epi32(3), abs_level, abs_level_smaller_than_two);
// Depending on whether the rem_reg_bins are smaller than four or not,
// the reg_bins_sub is either 0 or result of the above operation
__m128i rem_reg_bins_smaller_than_four = _mm_cmplt_epi32(rem_reg_bins, _mm_set1_epi32(4)); __m128i rem_reg_bins_smaller_than_four = _mm_cmplt_epi32(rem_reg_bins, _mm_set1_epi32(4));
reg_bins_sub = _mm_blendv_epi8(secondary, reg_bins_sub, rem_reg_bins_smaller_than_four); reg_bins_sub = _mm_blendv_epi8(secondary, reg_bins_sub, rem_reg_bins_smaller_than_four);
rem_reg_bins = _mm_sub_epi32(rem_reg_bins, reg_bins_sub); rem_reg_bins = _mm_sub_epi32(rem_reg_bins, reg_bins_sub);
_mm_store_si128((__m128i*)&state->m_remRegBins[state_offset], rem_reg_bins); _mm_store_si128((__m128i*)&state->m_remRegBins[state_offset], rem_reg_bins);
// Save whether all rem_reg_bins are gte 4 and whether all of them are smaller than 4,
// as these are needed in multiple places
__m128i mask = _mm_cmpgt_epi32(rem_reg_bins, _mm_set1_epi32(3)); __m128i mask = _mm_cmpgt_epi32(rem_reg_bins, _mm_set1_epi32(3));
int bit_mask = _mm_movemask_epi8(mask); int bit_mask = _mm_movemask_epi8(mask);
rem_reg_all_gte_4 = (bit_mask == 0xFFFF); rem_reg_all_gte_4 = (bit_mask == 0xFFFF);
@ -930,7 +988,7 @@ static INLINE void update_states_avx2(
bit_mask = _mm_movemask_epi8(mask); bit_mask = _mm_movemask_epi8(mask);
rem_reg_all_lt4 = (bit_mask == 0xFFFF); rem_reg_all_lt4 = (bit_mask == 0xFFFF);
// This is the same as in update_state_eos_avx2
__m128i temp_prev_state = _mm_shuffle_epi8(prv_states_o, control); __m128i temp_prev_state = _mm_shuffle_epi8(prv_states_o, control);
__m256i prev_state_256 = _mm256_castsi128_si256(temp_prev_state); __m256i prev_state_256 = _mm256_castsi128_si256(temp_prev_state);
prev_state_256 = _mm256_permute4x64_epi64(prev_state_256, 0); prev_state_256 = _mm256_permute4x64_epi64(prev_state_256, 0);
@ -950,15 +1008,21 @@ static INLINE void update_states_avx2(
_mm256_store_si256((__m256i*)&state->m_absLevels[ctxs->m_curr_state_offset >> 2][i], data); _mm256_store_si256((__m256i*)&state->m_absLevels[ctxs->m_curr_state_offset >> 2][i], data);
} }
// This is overall the same as absLevels but since the ctx values are two bytes all of the
// masks have to account for that
__m256i prev_state_full = _mm256_load_si256((__m256i const*)decisions->prevId); __m256i prev_state_full = _mm256_load_si256((__m256i const*)decisions->prevId);
__m256i shuffle_mask = _mm256_setr_epi8(0, 0, 4, 4,8, 8, 12, 12, 0, 0, 4, 4, 8, 8, 12, 12,0, 0, 0, 0,0, 0, 0, 0,16, 16, 16, 16, 16, 16, 16, 16); __m256i shuffle_mask = _mm256_setr_epi8(0, 0, 4, 4,8, 8, 12, 12, 0, 0, 4, 4, 8, 8, 12, 12,0, 0, 0, 0,0, 0, 0, 0,16, 16, 16, 16, 16, 16, 16, 16);
prev_state_full = _mm256_shuffle_epi8(prev_state_full, shuffle_mask); prev_state_full = _mm256_shuffle_epi8(prev_state_full, shuffle_mask);
prev_state_full = _mm256_permute4x64_epi64(prev_state_full, 0); prev_state_full = _mm256_permute4x64_epi64(prev_state_full, 0);
prev_state_full = _mm256_slli_epi16(prev_state_full, 1); prev_state_full = _mm256_slli_epi16(prev_state_full, 1);
temp_add = _mm256_setr_epi8(0, 1, 0, 1, 0, 1, 0, 1, 8, 9, 8, 9, 8, 9, 8, 9, 16, 17, 16, 17,16, 17,16, 17, 24, 25,24,25,24,25,24,25); temp_add = _mm256_setr_epi8(
0, 1, 0, 1, 0, 1, 0, 1,
8, 9, 8, 9, 8, 9, 8, 9,
16, 17, 16, 17, 16, 17, 16, 17,
24, 25, 24, 25, 24, 25, 24, 25);
prev_state_full = _mm256_add_epi8(prev_state_full, temp_add); prev_state_full = _mm256_add_epi8(prev_state_full, temp_add);
for (int i = 0; i < 64; i += (256 / 8 / sizeof(uint16_t))) { for (int i = 0; i < 64; i += (256 / (8 * sizeof(uint16_t)))) {
__m256i data = _mm256_load_si256((__m256i*)(&state->m_ctxInit[(ctxs->m_prev_state_offset >> 2)][i])); __m256i data = _mm256_load_si256((__m256i*)(&state->m_ctxInit[(ctxs->m_prev_state_offset >> 2)][i]));
data = _mm256_shuffle_epi8(data, prev_state_full); data = _mm256_shuffle_epi8(data, prev_state_full);
_mm256_store_si256((__m256i*)(&state->m_ctxInit[(state_offset >> 2)][i]), data); _mm256_store_si256((__m256i*)(&state->m_ctxInit[(state_offset >> 2)][i]), data);
@ -1016,6 +1080,7 @@ static INLINE void update_states_avx2(
rem_reg_all_lt4 &= state->m_remRegBins[state_id] < 4; rem_reg_all_lt4 &= state->m_remRegBins[state_id] < 4;
} }
{ {
// Same as for the all_non_negative but use blendv to set the shuffle mask to -1 for the states that do not have previous state
__m256i prev_state_full = _mm256_load_si256((__m256i const*)decisions->prevId); __m256i prev_state_full = _mm256_load_si256((__m256i const*)decisions->prevId);
__m256i shuffle_mask = _mm256_setr_epi8(0, 4, 8, 12, 0, 4, 8, 12, 0, 4, 8, 12, 0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1); __m256i shuffle_mask = _mm256_setr_epi8(0, 4, 8, 12, 0, 4, 8, 12, 0, 4, 8, 12, 0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1);
prev_state_full = _mm256_shuffle_epi8(prev_state_full, shuffle_mask); prev_state_full = _mm256_shuffle_epi8(prev_state_full, shuffle_mask);
@ -1198,6 +1263,7 @@ static INLINE void update_states_avx2(
if (extRiceFlag) { if (extRiceFlag) {
assert(0 && "Not implemented for avx2"); assert(0 && "Not implemented for avx2");
} else { } else {
// int sumAll = MAX(MIN(31, (int)sumAbs - 4 * 5), 0);
__m128i sum_all = _mm_max_epi32( __m128i sum_all = _mm_max_epi32(
_mm_min_epi32( _mm_min_epi32(
_mm_set1_epi32(31), _mm_set1_epi32(31),
@ -1257,7 +1323,7 @@ static INLINE void update_states_avx2(
int go_rice_par_i = _mm_extract_epi32(go_rice_par, 0); int go_rice_par_i = _mm_extract_epi32(go_rice_par, 0);
memcpy(&state->m_goRicePar[state_offset], &go_rice_par_i, 4); memcpy(&state->m_goRicePar[state_offset], &go_rice_par_i, 4);
// This cannot be vectorized because there is no way to dynamically shift values
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
state->m_goRiceZero[state_offset + i] = (i < 2 ? 1 : 2) << state->m_goRicePar[state_offset + i]; state->m_goRiceZero[state_offset + i] = (i < 2 ? 1 : 2) << state->m_goRicePar[state_offset + i];
} }