[avx2] Replace AVX-512 loads and stores with non-AVX-512 equivalents

This commit is contained in:
Joose Sainio 2023-04-19 13:07:38 +03:00
parent 6d0a3fa5fc
commit 8b1d6fab59

View file

@ -76,9 +76,9 @@ typedef struct
// Decision data for eight candidates: an RD cost together with the matching
// absolute level and previous-state id for each entry.
typedef struct
{
int64_t rdCost[8];
int32_t absLevel[8];
int32_t prevId[8];
// Each array is 32-byte aligned so the AVX2 code can access it with the
// aligned load/store intrinsics (_mm256_load_si256, _mm256_store_si256,
// _mm_load_si128) instead of the AVX-512VL *_loadu_epi*/*_storeu_epi* forms.
int64_t ALIGNED(32) rdCost[8];
int32_t ALIGNED(32) absLevel[8];
int32_t ALIGNED(32) prevId[8];
} Decision;
@ -118,19 +118,19 @@ typedef struct
typedef struct
{
int64_t m_rdCost[12];
uint16_t m_absLevelsAndCtxInit[12][24]; // 16x8bit for abs levels + 16x16bit for ctx init id
int8_t m_numSigSbb[12];
int m_remRegBins[12];
int8_t m_refSbbCtxId[12];
uint32_t m_sbbFracBits[12][2];
uint32_t m_sigFracBits[12][2];
int32_t m_coeffFracBits[12][6];
int8_t m_goRicePar[12];
int8_t m_goRiceZero[12];
int8_t m_stateId[12];
uint32_t m_sigFracBitsArray[12][12][2];
int32_t m_gtxFracBitsArray[21][6];
int64_t ALIGNED(32) m_rdCost[12];
uint16_t ALIGNED(32) m_absLevelsAndCtxInit[12][24]; // 16x8bit for abs levels + 16x16bit for ctx init id
int8_t ALIGNED(16) m_numSigSbb[12];
int ALIGNED(32) m_remRegBins[12];
int8_t ALIGNED(16) m_refSbbCtxId[12];
uint32_t ALIGNED(32) m_sbbFracBits[12][2];
uint32_t ALIGNED(32) m_sigFracBits[12][2];
int32_t ALIGNED(32) m_coeffFracBits[12][6];
int8_t ALIGNED(16) m_goRicePar[12];
int8_t ALIGNED(16) m_goRiceZero[12];
int8_t ALIGNED(16) m_stateId[12];
uint32_t ALIGNED(32) m_sigFracBitsArray[12][12][2];
int32_t ALIGNED(32) m_gtxFracBitsArray[21][6];
common_context* m_commonCtx;
unsigned effWidth;
@ -715,10 +715,10 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
}
else {
const int pqAs[4] = {0, 0, 3, 3};
_mm256_storeu_epi64(temp_rd_cost_a, rd_cost_a);
_mm256_storeu_epi64(temp_rd_cost_b, rd_cost_b);
_mm256_storeu_epi64(temp_rd_cost_z, rd_cost_z);
const int ALIGNED(32) pqAs[4] = {0, 0, 3, 3};
// Spill the three cost vectors to scalar-addressable scratch buffers.
// Unaligned stores are used on purpose: the temp_rd_cost_* buffers are read
// back with _mm256_loadu_si256 after the scalar fix-up loop, and nothing here
// guarantees they are 32-byte aligned, which _mm256_store_si256 would require.
_mm256_storeu_si256((__m256i*)temp_rd_cost_a, rd_cost_a);
_mm256_storeu_si256((__m256i*)temp_rd_cost_b, rd_cost_b);
_mm256_storeu_si256((__m256i*)temp_rd_cost_z, rd_cost_z);
for (int i = 0; i < 4; i++) {
const int state_offset = start + i;
if (state->m_numSigSbb[state_offset]) {
@ -729,15 +729,15 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
temp_rd_cost_z[i] = decisions->rdCost[pqAs[i]];
}
}
rd_cost_a = _mm256_loadu_epi64(temp_rd_cost_a);
rd_cost_b = _mm256_loadu_epi64(temp_rd_cost_b);
rd_cost_z = _mm256_loadu_epi64(temp_rd_cost_z);
rd_cost_a = _mm256_loadu_si256((__m256i*)temp_rd_cost_a);
rd_cost_b = _mm256_loadu_si256((__m256i*)temp_rd_cost_b);
rd_cost_z = _mm256_loadu_si256((__m256i*)temp_rd_cost_z);
}
}
} else if (state->all_lt_four) {
__m128i scale_bits = _mm_set1_epi32(1 << SCALE_BITS);
__m128i max_rice = _mm_set1_epi32(31);
__m128i go_rice_zero = _mm_cvtepi8_epi32(_mm_loadu_epi8(&state->m_goRiceZero[start]));
__m128i go_rice_zero = _mm_cvtepi8_epi32(_mm_loadu_si128((const __m128i*)&state->m_goRiceZero[start]));
// RD cost A
{
__m128i pq_abs_a = _mm_set_epi32(pqDataA->absLevel[3], pqDataA->absLevel[3], pqDataA->absLevel[0], pqDataA->absLevel[0]);
@ -750,7 +750,7 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
__m128i selected = _mm_blendv_epi8(other, go_rice_smaller, cmp);
__m128i go_rice_offset = _mm_cvtepi8_epi32(_mm_loadu_si32(&state->m_goRicePar[start]));
__m128i go_rice_offset = _mm_cvtepi8_epi32(_mm_loadu_si128((__m128i*)&state->m_goRicePar[start]));
go_rice_offset = _mm_slli_epi32(go_rice_offset, 5);
__m128i offsets = _mm_add_epi32(selected, go_rice_offset);
@ -771,7 +771,7 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
__m128i selected = _mm_blendv_epi8(other, go_rice_smaller, cmp);
__m128i go_rice_offset = _mm_cvtepi8_epi32(_mm_loadu_si32(&state->m_goRicePar[start]));
__m128i go_rice_offset = _mm_cvtepi8_epi32(_mm_loadu_si128((__m128i*)&state->m_goRicePar[start]));
go_rice_offset = _mm_slli_epi32(go_rice_offset, 5);
__m128i offsets = _mm_add_epi32(selected, go_rice_offset);
@ -782,7 +782,7 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
}
// RD cost Z
{
__m128i go_rice_offset = _mm_cvtepi8_epi32(_mm_loadu_si32(&state->m_goRicePar[start]));
__m128i go_rice_offset = _mm_cvtepi8_epi32(_mm_loadu_si128((__m128i*)&state->m_goRicePar[start]));
go_rice_offset = _mm_slli_epi32(go_rice_offset, 5);
go_rice_offset = _mm_add_epi32(go_rice_offset, go_rice_zero);
@ -838,17 +838,17 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
temp_rd_cost_b[i] = rdCostB;
temp_rd_cost_z[i] = rdCostZ;
}
rd_cost_a = _mm256_loadu_epi64(temp_rd_cost_a);
rd_cost_b = _mm256_loadu_epi64(temp_rd_cost_b);
rd_cost_z = _mm256_loadu_epi64(temp_rd_cost_z);
rd_cost_a = _mm256_loadu_si256((__m256i*)temp_rd_cost_a);
rd_cost_b = _mm256_loadu_si256((__m256i*)temp_rd_cost_b);
rd_cost_z = _mm256_loadu_si256((__m256i*)temp_rd_cost_z);
}
rd_cost_a = _mm256_permute4x64_epi64(rd_cost_a, 216);
rd_cost_b = _mm256_permute4x64_epi64(rd_cost_b, 141);
rd_cost_z = _mm256_permute4x64_epi64(rd_cost_z, 216);
__m256i rd_cost_decision = _mm256_loadu_epi64(decisions->rdCost);
__m256i rd_cost_decision = _mm256_load_si256((__m256i*)decisions->rdCost);
__m256i decision_abs_coeff = _mm256_loadu_epi32(decisions->absLevel);
__m256i decision_prev_state = _mm256_loadu_epi32(decisions->prevId);
__m256i decision_abs_coeff = _mm256_load_si256((__m256i*)decisions->absLevel);
__m256i decision_prev_state = _mm256_load_si256((__m256i*)decisions->prevId);
__m256i decision_data = _mm256_permute2x128_si256(decision_abs_coeff, decision_prev_state, 0x20);
__m256i mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
decision_data = _mm256_permutevar8x32_epi32(decision_data, mask);
@ -869,7 +869,7 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
__m256i final_rd_cost = _mm256_blendv_epi8(cheaper_first, cheaper_second, final_decision);
__m256i final_data = _mm256_blendv_epi8(cheaper_first_data, cheaper_second_data, final_decision);
_mm256_storeu_epi64(decisions->rdCost, final_rd_cost);
_mm256_store_si256((__m256i*)decisions->rdCost, final_rd_cost);
final_data = _mm256_permutevar8x32_epi32(final_data, _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0));
_mm256_storeu2_m128i((__m128i *)decisions->prevId, (__m128i *)decisions->absLevel, final_data);
}
@ -1172,8 +1172,8 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
int state_offset = ctxs->m_curr_state_offset;
__m256i rd_cost = _mm256_loadu_epi64(decisions->rdCost);
_mm256_storeu_epi64(&ctxs->m_allStates.m_rdCost[state_offset], rd_cost);
__m256i rd_cost = _mm256_load_si256((__m256i const*)decisions->rdCost);
_mm256_store_si256((__m256i *)& ctxs->m_allStates.m_rdCost[state_offset], rd_cost);
for (int i = 0; i < 4; ++i) {
all_above_minus_two &= decisions->prevId[i] > -2;
all_between_zero_and_three &= decisions->prevId[i] >= 0 && decisions->prevId[i] < 4;
@ -1183,10 +1183,10 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
bool all_have_previous_state = true;
__m128i prev_state;
__m128i prev_state_no_offset;
__m128i abs_level = _mm_loadu_epi32(decisions->absLevel);
__m128i abs_level = _mm_load_si128((const __m128i*)decisions->absLevel);
if (all_above_four) {
prev_state = _mm_set1_epi32(ctxs->m_skip_state_offset);
prev_state_no_offset = _mm_sub_epi32(_mm_loadu_epi32(decisions->prevId), _mm_set1_epi32(4));
prev_state_no_offset = _mm_sub_epi32(_mm_load_si128((const __m128i*)decisions->prevId), _mm_set1_epi32(4));
prev_state = _mm_add_epi32(
prev_state,
prev_state_no_offset
@ -1199,11 +1199,11 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
prev_state_no_offset = _mm_set1_epi32(ctxs->m_prev_state_offset);
prev_state = _mm_add_epi32(
prev_state_no_offset,
_mm_loadu_epi32(decisions->prevId)
_mm_load_si128((const __m128i*)decisions->prevId)
);
__m128i control = _mm_setr_epi8(0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1);
__m128i prev_state_with_ff_high_bytes = _mm_or_epi32(prev_state, _mm_set1_epi32(0xffffff00));
__m128i num_sig_sbb = _mm_loadu_epi32(state->m_numSigSbb);
__m128i num_sig_sbb = _mm_load_si128((const __m128i*)state->m_numSigSbb);
num_sig_sbb = _mm_shuffle_epi8(num_sig_sbb, prev_state_with_ff_high_bytes);
num_sig_sbb = _mm_add_epi32(
num_sig_sbb,
@ -1215,24 +1215,21 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
memcpy(&state->m_numSigSbb[state_offset], &num_sig_sbb_s, 4);
int32_t prev_state_scalar[4];
_mm_storeu_epi32(prev_state_scalar, prev_state);
_mm_storeu_si128((__m128i*)prev_state_scalar, prev_state);
for (int i = 0; i < 4; ++i) {
memcpy(state->m_absLevelsAndCtxInit[state_offset + i], state->m_absLevelsAndCtxInit[prev_state_scalar[i]], 16 * sizeof(uint8_t));
}
} else {
int prev_state_s[4] = {-1, -1, -1, -1};
int prev_state_no_offset_s[4] = {-1, -1, -1, -1};
for (int i = 0; i < 4; ++i) {
const int decision_id = i;
const int curr_state_offset = state_offset + i;
if (decisions->prevId[decision_id] >= 4) {
prev_state_s[i] = ctxs->m_skip_state_offset + (decisions->prevId[decision_id] - 4);
prev_state_no_offset_s[i] = decisions->prevId[decision_id] - 4;
state->m_numSigSbb[curr_state_offset] = 0;
memset(state->m_absLevelsAndCtxInit[curr_state_offset], 0, 16 * sizeof(uint8_t));
} else if (decisions->prevId[decision_id] >= 0) {
prev_state_s[i] = ctxs->m_prev_state_offset + decisions->prevId[decision_id];
prev_state_no_offset_s[i] = decisions->prevId[decision_id];
state->m_numSigSbb[curr_state_offset] = state->m_numSigSbb[prev_state_s[i]] + !!decisions->absLevel[decision_id];
memcpy(state->m_absLevelsAndCtxInit[curr_state_offset], state->m_absLevelsAndCtxInit[prev_state_s[i]], 16 * sizeof(uint8_t));
} else {
@ -1241,13 +1238,12 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
all_have_previous_state = false;
}
}
prev_state = _mm_loadu_epi32(prev_state_s);
prev_state_no_offset = _mm_loadu_epi32(prev_state_no_offset_s);
prev_state = _mm_loadu_si128((__m128i const*)prev_state_s);
}
uint32_t level_offset = scan_pos & 15;
__m128i max_abs = _mm_min_epi32(abs_level, _mm_set1_epi32(32));
uint32_t max_abs_s[4];
_mm_storeu_epi32(max_abs_s, max_abs);
_mm_storeu_si128((__m128i*)max_abs_s, max_abs);
for (int i = 0; i < 4; ++i) {
uint8_t* levels = (uint8_t*)state->m_absLevelsAndCtxInit[state_offset + i];
levels[level_offset] = max_abs_s[i];
@ -1260,7 +1256,7 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
common_context* cc = &ctxs->m_common_context;
size_t setCpSize = cc->m_nbInfo[scan_pos - 1].maxDist * sizeof(uint8_t);
int previous_state_array[4];
_mm_storeu_epi32(previous_state_array, prev_state);
_mm_storeu_si128((__m128i*)previous_state_array, prev_state);
for (int curr_state = 0; curr_state < 4; ++curr_state) {
uint8_t* sbbFlags = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset + (curr_state)].sbbFlags;
uint8_t* levels = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset + (curr_state)].levels;
@ -1288,17 +1284,8 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
__m128i sig_sbb = _mm_or_epi32(sbb_right, sbb_below);
sig_sbb = _mm_and_si128(sig_sbb, _mm_set1_epi32(0xff));
sig_sbb = _mm_min_epi32(sig_sbb, _mm_set1_epi32(1));
//__m256i sig_sbb_mask = _mm256_cvtepi32_epi64(sig_sbb);
//const __m256i duplication_mask = _mm256_setr_epi8(
// 0, 0, 0, 0, 0, 0, 0, 0,
// 1, 1, 1, 1, 1, 1, 1, 1,
// 2, 2, 2, 2, 2, 2, 2, 2,
// 3, 3, 3, 3, 3, 3, 3, 3);
//sig_sbb_mask = _mm256_shuffle_epi8(sig_sbb_mask, duplication_mask);
__m256i sbb_frac_bits = _mm256_i32gather_epi64((int64_t *)cc->m_sbbFlagBits[0], sig_sbb, 8);
//__m256i sbb_frac_bits = _mm256_loadu_epi64(cc->m_sbbFlagBits);
//sbb_frac_bits = _mm256_shu
_mm256_storeu_epi64(state->m_sbbFracBits[state_offset], sbb_frac_bits);
_mm256_store_si256((__m256i*)state->m_sbbFracBits[state_offset], sbb_frac_bits);
memset(&state->m_numSigSbb[state_offset], 0, 4);
memset(&state->m_goRicePar[state_offset], 0, 4);
@ -1307,10 +1294,7 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
memcpy(&state->m_refSbbCtxId[state_offset], states, 4);
if (all_have_previous_state) {
__m128i rem_reg_bins = _mm_i32gather_epi32(state->m_remRegBins, prev_state, 4);
//prev_state_no_offset = _mm_shuffle_epi8(prev_state_no_offset, _mm_setr_epi8(0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3));
//__m128i rem_reg_bins = _mm_loadu_epi32(&state->m_remRegBins[previous_state_array[0] & 0xfc]);
//rem_reg_bins = _mm_shuffle_epi8(rem_reg_bins, mask);
_mm_storeu_epi32(&state->m_remRegBins[state_offset], rem_reg_bins);
_mm_store_si128((__m128i*) & state->m_remRegBins[state_offset], rem_reg_bins);
} else {
const int temp = (state->effWidth * state->effHeight * 28) / 16;
for (int i = 0; i < 4; ++i) {
@ -1339,7 +1323,7 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
if (nbOut->num == 0) {
temp[id % 4] = 0;
if (id % 4 == 3) {
all[id / 4] = _mm256_loadu_epi64(temp);
all[id / 4] = _mm256_loadu_si256((__m256i const*)temp);
all[id / 4] = _mm256_shuffle_epi8(all[id / 4], v_shuffle);
}
continue;
@ -1427,8 +1411,8 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
__m128i shuffle_mask = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 0, 0, 0, 0, 0, 0, 0, 0);
__m128i shuffled_template_ctx_init = _mm_shuffle_epi8(template_ctx_init, shuffle_mask);
temp[id % 4] = _mm_extract_epi64(shuffled_template_ctx_init, 0);
if (id %4 == 3) {
all[id / 4] = _mm256_loadu_epi64(temp);
if (id % 4 == 3) {
all[id / 4] = _mm256_loadu_si256((__m256i const*)temp);
all[id / 4] = _mm256_shuffle_epi8(all[id / 4], v_shuffle);
last = template_ctx_init;
}
@ -1454,10 +1438,10 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
v_tmp[2] = _mm256_permute4x64_epi64(v_tmp16_hi[0], _MM_SHUFFLE(3, 1, 2, 0));
v_tmp[3] = _mm256_permute4x64_epi64(v_tmp16_hi[1], _MM_SHUFFLE(3, 1, 2, 0));
_mm256_storeu_epi16(state->m_absLevelsAndCtxInit[state_offset] + 8, _mm256_permute2x128_si256(v_tmp[0], v_tmp[1], 0x20));
_mm256_storeu_epi16(state->m_absLevelsAndCtxInit[state_offset + 1] + 8, _mm256_permute2x128_si256(v_tmp[0], v_tmp[1], 0x31));
_mm256_storeu_epi16(state->m_absLevelsAndCtxInit[state_offset + 2] + 8, _mm256_permute2x128_si256(v_tmp[2], v_tmp[3], 0x20));
_mm256_storeu_epi16(state->m_absLevelsAndCtxInit[state_offset + 3] + 8, _mm256_permute2x128_si256(v_tmp[2], v_tmp[3], 0x31));
// Write the 16 x 16-bit ctx-init values into the upper part of each of the
// four state rows. These must be unaligned stores: a row of
// m_absLevelsAndCtxInit is 24 * sizeof(uint16_t) = 48 bytes and the target
// starts 8 elements (16 bytes) into the row, i.e. at base + 48 * row + 16.
// For every even row index that address is only 16-byte aligned, so the
// aligned _mm256_store_si256 (32-byte alignment required) could fault.
_mm256_storeu_si256((__m256i*)(state->m_absLevelsAndCtxInit[state_offset] + 8), _mm256_permute2x128_si256(v_tmp[0], v_tmp[1], 0x20));
_mm256_storeu_si256((__m256i*)(state->m_absLevelsAndCtxInit[state_offset + 1] + 8), _mm256_permute2x128_si256(v_tmp[0], v_tmp[1], 0x31));
_mm256_storeu_si256((__m256i*)(state->m_absLevelsAndCtxInit[state_offset + 2] + 8), _mm256_permute2x128_si256(v_tmp[2], v_tmp[3], 0x20));
_mm256_storeu_si256((__m256i*)(state->m_absLevelsAndCtxInit[state_offset + 3] + 8), _mm256_permute2x128_si256(v_tmp[2], v_tmp[3], 0x31));
for (int i = 0; i < 4; ++i) {
memset(state->m_absLevelsAndCtxInit[state_offset + i], 0, 16);
@ -1479,13 +1463,13 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos,
offsets = _mm_add_epi32(offsets, _mm_set1_epi32(sigCtxOffsetNext));
offsets = _mm_add_epi32(offsets, sum_abs_min);
__m256i sig_frac_bits = _mm256_i32gather_epi64((const int64_t *)&state->m_sigFracBitsArray[state_offset][0][0], offsets, 8);
_mm256_storeu_epi64(&state->m_sigFracBits[state_offset][0], sig_frac_bits);
_mm256_store_si256((__m256i*)&state->m_sigFracBits[state_offset][0], sig_frac_bits);
__m128i sum_gt1 = _mm_sub_epi32(sum_abs1, sum_num);
__m128i min_gt1 = _mm_min_epi32(sum_gt1, _mm_set1_epi32(4));
uint32_t sum_gt1_s[4];
_mm_storeu_epi32(sum_gt1_s, min_gt1);
_mm_storeu_si128((__m128i*)sum_gt1_s, min_gt1);
for (int i = 0; i < 4; ++i) {
memcpy(state->m_coeffFracBits[state_offset + i], state->m_gtxFracBitsArray[sum_gt1_s[i] + gtxCtxOffsetNext], sizeof(state->m_coeffFracBits[0]));
}
@ -1592,22 +1576,22 @@ static INLINE void update_states_avx2(
all_minus_one &= decisions->prevId[i] == -1;
}
int state_offset = ctxs->m_curr_state_offset;
__m256i rd_cost = _mm256_loadu_epi64(decisions->rdCost);
_mm256_storeu_epi64(&ctxs->m_allStates.m_rdCost[state_offset], rd_cost);
__m256i rd_cost = _mm256_load_si256((__m256i const*)decisions->rdCost);
_mm256_store_si256((__m256i *)& ctxs->m_allStates.m_rdCost[state_offset], rd_cost);
if (all_above_minus_two) {
bool rem_reg_all_gte_4 = true;
bool rem_reg_all_lt4 = true;
__m128i abs_level = _mm_loadu_epi32(decisions->absLevel);
__m128i abs_level = _mm_load_si128((__m128i const*)decisions->absLevel);
if (all_non_negative) {
__m128i prv_states = _mm_loadu_epi32(decisions->prevId);
__m128i prv_states = _mm_load_si128((__m128i const*)decisions->prevId);
__m128i prev_offset = _mm_set1_epi32(ctxs->m_prev_state_offset);
prv_states = _mm_add_epi32(prv_states, prev_offset);
__m128i control = _mm_setr_epi8(0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1);
__m128i shuffled_prev_states = _mm_shuffle_epi8(prv_states, control);
__m128i sig_sbb = _mm_loadu_epi32(state->m_numSigSbb);
__m128i sig_sbb = _mm_load_si128((__m128i const*)state->m_numSigSbb);
sig_sbb = _mm_shuffle_epi8(sig_sbb, shuffled_prev_states);
__m128i has_coeff = _mm_min_epi32(abs_level, _mm_set1_epi32(1));
has_coeff = _mm_shuffle_epi8(has_coeff, control);
@ -1615,19 +1599,19 @@ static INLINE void update_states_avx2(
int sig_sbb_i = _mm_extract_epi32(sig_sbb, 0);
memcpy(&state->m_numSigSbb[state_offset], &sig_sbb_i, 4);
__m128i ref_sbb_ctx_idx = _mm_loadu_epi32(state->m_refSbbCtxId);
__m128i ref_sbb_ctx_idx = _mm_load_si128((__m128i const*)state->m_refSbbCtxId);
ref_sbb_ctx_idx = _mm_shuffle_epi8(ref_sbb_ctx_idx, shuffled_prev_states);
int ref_sbb_ctx = _mm_extract_epi32(ref_sbb_ctx_idx, 0);
memcpy(&state->m_refSbbCtxId[state_offset], &ref_sbb_ctx, 4);
__m128i go_rice_par = _mm_loadu_epi32(state->m_goRicePar);
__m128i go_rice_par = _mm_load_si128((__m128i const*)state->m_goRicePar);
go_rice_par = _mm_shuffle_epi8(go_rice_par, shuffled_prev_states);
int go_rice_par_i = _mm_extract_epi32(go_rice_par, 0);
memcpy(&state->m_goRicePar[state_offset], &go_rice_par_i, 4);
__m256i sbb_frac_bits = _mm256_i32gather_epi64((const int64_t *)state->m_sbbFracBits[0], prv_states, 8);
_mm256_storeu_epi64(&state->m_sbbFracBits[state_offset][0], sbb_frac_bits);
_mm256_store_si256((__m256i*)&state->m_sbbFracBits[state_offset][0], sbb_frac_bits);
__m128i rem_reg_bins = _mm_i32gather_epi32(state->m_remRegBins, prv_states, 4);
__m128i ones = _mm_set1_epi32(1);
@ -1640,7 +1624,7 @@ static INLINE void update_states_avx2(
__m128i rem_reg_bins_smaller_than_four = _mm_cmplt_epi32(rem_reg_bins, _mm_set1_epi32(4));
reg_bins_sub = _mm_blendv_epi8(secondary, reg_bins_sub, rem_reg_bins_smaller_than_four);
rem_reg_bins = _mm_sub_epi32(rem_reg_bins, reg_bins_sub);
_mm_storeu_epi32(&state->m_remRegBins[state_offset], rem_reg_bins);
_mm_store_si128((__m128i*)&state->m_remRegBins[state_offset], rem_reg_bins);
__m128i mask = _mm_cmpgt_epi32(rem_reg_bins, _mm_set1_epi32(3));
int bit_mask = _mm_movemask_epi8(mask);
@ -1650,7 +1634,7 @@ static INLINE void update_states_avx2(
rem_reg_all_lt4 = (bit_mask == 0xFFFF);
int32_t prv_states_scalar[4];
_mm_storeu_epi32(prv_states_scalar, prv_states);
_mm_storeu_si128((__m128i*)prv_states_scalar, prv_states);
for (int i = 0; i < 4; ++i) {
memcpy(state->m_absLevelsAndCtxInit[state_offset + i], state->m_absLevelsAndCtxInit[prv_states_scalar[i]], 48 * sizeof(uint8_t));
}
@ -1668,7 +1652,7 @@ static INLINE void update_states_avx2(
_mm_cmplt_epi32(abs_level, _mm_set1_epi32(2))
);
rem_reg_bins = _mm_sub_epi32(rem_reg_bins, sub);
_mm_storeu_epi32(&state->m_remRegBins[state_offset], rem_reg_bins);
_mm_store_si128((__m128i*) & state->m_remRegBins[state_offset], rem_reg_bins);
__m128i mask = _mm_cmpgt_epi32(rem_reg_bins, _mm_set1_epi32(3));
int bit_mask = _mm_movemask_epi8(mask);
@ -1711,7 +1695,7 @@ static INLINE void update_states_avx2(
uint32_t level_offset = scan_pos & 15;
__m128i max_abs = _mm_min_epi32(abs_level, _mm_set1_epi32(32));
uint32_t max_abs_s[4];
_mm_storeu_epi32(max_abs_s, max_abs);
_mm_storeu_si128((__m128i*)max_abs_s, max_abs);
for (int i = 0; i < 4; ++i) {
uint8_t* levels = (uint8_t*)state->m_absLevelsAndCtxInit[state_offset + i];
levels[level_offset] = max_abs_s[i];
@ -1840,12 +1824,12 @@ static INLINE void update_states_avx2(
_mm_set1_epi32(3));
offsets = _mm_add_epi32(offsets, temp);
__m256i sig_frac_bits = _mm256_i32gather_epi64((const int64_t *)state->m_sigFracBitsArray[state_offset][0], offsets, 8);
_mm256_storeu_epi64(&state->m_sigFracBits[state_offset][0], sig_frac_bits);
_mm256_store_si256((__m256i*)&state->m_sigFracBits[state_offset][0], sig_frac_bits);
sum_gt1 = _mm_min_epi32(sum_gt1, _mm_set1_epi32(4));
sum_gt1 = _mm_add_epi32(sum_gt1, _mm_set1_epi32(gtxCtxOffsetNext));
uint32_t sum_gt1_s[4];
_mm_storeu_epi32(sum_gt1_s, sum_gt1);
_mm_storeu_si128((__m128i*)sum_gt1_s, sum_gt1);
for (int i = 0; i < 4; ++i) {
memcpy(state->m_coeffFracBits[state_offset + i], state->m_gtxFracBitsArray[sum_gt1_s[i]], sizeof(state->m_coeffFracBits[0]));
}