From 9ab59fcc241822ca47fe0a18c65d3cd65cc18267 Mon Sep 17 00:00:00 2001 From: Joose Sainio Date: Tue, 18 Apr 2023 15:43:30 +0300 Subject: [PATCH] [depquant] update_state_eos_avx2 working --- src/dep_quant.c | 96 +++++++++++++++++++++++++++++-------------------- 1 file changed, 58 insertions(+), 38 deletions(-) diff --git a/src/dep_quant.c b/src/dep_quant.c index ef73d7ed..cef534fa 100644 --- a/src/dep_quant.c +++ b/src/dep_quant.c @@ -708,16 +708,10 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en __m256i odd_64 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(odd, 0)); rd_cost_a = _mm256_add_epi64(rd_cost_a, odd_64); rd_cost_b = _mm256_add_epi64(rd_cost_b, odd_64); - rd_cost_z = _mm256_add_epi64(rd_cost_z, even_64); - _mm256_storeu_epi64(temp_rd_cost_a, rd_cost_a); - _mm256_storeu_epi64(temp_rd_cost_b, rd_cost_b); - _mm256_storeu_epi64(temp_rd_cost_z, rd_cost_z); + rd_cost_z = _mm256_add_epi64(rd_cost_z, even_64); } else if (!state->m_numSigSbb[start] && !state->m_numSigSbb[start + 1] && !state->m_numSigSbb[start + 2] && !state->m_numSigSbb[start + 3]) { rd_cost_z = _mm256_setr_epi64x(decisions->rdCost[0], decisions->rdCost[0], decisions->rdCost[3], decisions->rdCost[3]); - _mm256_storeu_epi64(temp_rd_cost_a, rd_cost_a); - _mm256_storeu_epi64(temp_rd_cost_b, rd_cost_b); - _mm256_storeu_epi64(temp_rd_cost_z, rd_cost_z); } else { @@ -735,11 +729,11 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en temp_rd_cost_z[i] = decisions->rdCost[pqAs[i]]; } } + rd_cost_a = _mm256_loadu_epi64(temp_rd_cost_a); + rd_cost_b = _mm256_loadu_epi64(temp_rd_cost_b); + rd_cost_z = _mm256_loadu_epi64(temp_rd_cost_z); } } - _mm256_storeu_epi64(temp_rd_cost_a, rd_cost_a); - _mm256_storeu_epi64(temp_rd_cost_b, rd_cost_b); - _mm256_storeu_epi64(temp_rd_cost_z, rd_cost_z); } else if (state->all_lt_four) { __m128i scale_bits = _mm_set1_epi32(1 << SCALE_BITS); __m128i max_rice = _mm_set1_epi32(31); @@ -795,9 +789,6 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en __m128i go_rice_tab = _mm_i32gather_epi32(&g_goRiceBits[0][0], go_rice_offset, 4); rd_cost_z = _mm256_add_epi64(rd_cost_z, _mm256_cvtepi32_epi64(go_rice_tab)); } - _mm256_storeu_epi64(temp_rd_cost_a, rd_cost_a); - _mm256_storeu_epi64(temp_rd_cost_b, rd_cost_b); - _mm256_storeu_epi64(temp_rd_cost_z, rd_cost_z); } else { const int pqAs[4] = {0, 0, 3, 3}; const int pqBs[4] = {2, 2, 1, 1}; @@ -1206,25 +1197,22 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos, memset(state->m_absLevelsAndCtxInit[state_offset + i], 0, 16 * sizeof(uint8_t)); } } else if (all_between_zero_and_three) { - prev_state = _mm_set1_epi32(ctxs->m_skip_state_offset); + prev_state = _mm_set1_epi32(ctxs->m_prev_state_offset); prev_state = _mm_add_epi32( prev_state, - _mm_sub_epi32( - _mm_loadu_epi32(decisions->prevId), - _mm_set1_epi32(4) - ) + _mm_loadu_epi32(decisions->prevId) ); __m128i num_sig_sbb = _mm_i32gather_epi32(&state->m_numSigSbb[state_offset], prev_state, 1); num_sig_sbb = _mm_and_epi32(num_sig_sbb, _mm_set1_epi32(0xff)); - num_sig_sbb = _mm_and_epi32( + num_sig_sbb = _mm_add_epi32( num_sig_sbb, - _mm_max_epi32(abs_level, _mm_set1_epi32(1)) + _mm_min_epi32(abs_level, _mm_set1_epi32(1)) ); __m128i control = _mm_setr_epi8(0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1); num_sig_sbb = _mm_shuffle_epi8(num_sig_sbb, control); int num_sig_sbb_s = _mm_extract_epi32(num_sig_sbb, 0); - memcpy(&state->m_refSbbCtxId[state_offset], &num_sig_sbb_s, 4); + memcpy(&state->m_numSigSbb[state_offset], &num_sig_sbb_s, 4); int32_t prev_state_scalar[4]; _mm_storeu_epi32(prev_state_scalar, prev_state); @@ -1288,13 +1276,14 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos, __m128i sbb_offsets = _mm_set_epi32(3 * numSbb, 2 * numSbb, 1 * numSbb, 0); __m128i next_sbb_right_m = _mm_set1_epi32(next_sbb_right); __m128i sbb_offsets_right = _mm_add_epi32(sbb_offsets, next_sbb_right_m); - __m128i sbb_right = next_sbb_right ? _mm_i32gather_epi32(&cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset].sbbFlags, sbb_offsets_right, 1) : _mm_set1_epi32(0); + __m128i sbb_right = next_sbb_right ? _mm_i32gather_epi32(cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset].sbbFlags, sbb_offsets_right, 1) : _mm_set1_epi32(0); __m128i sbb_offsets_below = _mm_add_epi32(sbb_offsets, _mm_set1_epi32(next_sbb_below)); - __m128i sbb_below = next_sbb_right ? _mm_i32gather_epi32(&cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset].sbbFlags, sbb_offsets_below, 1) : _mm_set1_epi32(0); + __m128i sbb_below = next_sbb_below ? _mm_i32gather_epi32(cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset].sbbFlags, sbb_offsets_below, 1) : _mm_set1_epi32(0); __m128i sig_sbb = _mm_or_epi32(sbb_right, sbb_below); - sig_sbb = _mm_max_epi32(sig_sbb, _mm_set1_epi32(1)); + sig_sbb = _mm_and_epi32(sig_sbb, _mm_set1_epi32(0xff)); + sig_sbb = _mm_min_epi32(sig_sbb, _mm_set1_epi32(1)); __m256i sbb_frac_bits = _mm256_i32gather_epi64(cc->m_sbbFlagBits, sig_sbb, 8); _mm256_storeu_epi64(state->m_sbbFracBits[state_offset], sbb_frac_bits); @@ -1327,11 +1316,15 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos, __m128i fours = _mm_set1_epi32(4); __m256i all[4]; uint64_t temp[4]; + const __m256i v_shuffle = _mm256_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0, + 31, 30, 23, 22, 29, 28, 21, 20, 27, 26, 19, 18, 25, 24, 17, 16); + for (int id = 0; id < 16; id++, nbOut++) { if (nbOut->num == 0) { temp[id % 4] = 0; if (id % 4 == 3) { - all[0] = _mm256_loadu_epi64(temp); + all[id / 4] = _mm256_loadu_epi64(temp); + all[id / 4] = _mm256_shuffle_epi8(all[id / 4], v_shuffle); } continue; } @@ -1345,7 +1338,7 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos, __m128i t = _mm_i32gather_epi32(absLevels, offset, 1); t = _mm_and_epi32(t, first_byte); sum_abs = _mm_add_epi32(sum_abs, t); - sum_num = _mm_add_epi32(sum_num, _mm_max_epi32(t, ones)); + sum_num = _mm_add_epi32(sum_num, _mm_min_epi32(t, ones)); __m128i min_t = _mm_min_epi32( t, _mm_add_epi32( @@ -1360,7 +1353,7 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos, __m128i t = _mm_i32gather_epi32(absLevels, offset, 1); t = _mm_and_epi32(t, first_byte); sum_abs = _mm_add_epi32(sum_abs, t); - sum_num = _mm_add_epi32(sum_num, _mm_max_epi32(t, ones)); + sum_num = _mm_add_epi32(sum_num, _mm_min_epi32(t, ones)); __m128i min_t = _mm_min_epi32( t, _mm_add_epi32( @@ -1373,7 +1366,7 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos, __m128i t = _mm_i32gather_epi32(absLevels, offset, 1); t = _mm_and_epi32(t, first_byte); sum_abs = _mm_add_epi32(sum_abs, t); - sum_num = _mm_add_epi32(sum_num, _mm_max_epi32(t, ones)); + sum_num = _mm_add_epi32(sum_num, _mm_min_epi32(t, ones)); __m128i min_t = _mm_min_epi32( t, _mm_add_epi32( @@ -1386,7 +1379,7 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos, __m128i t = _mm_i32gather_epi32(absLevels, offset, 1); t = _mm_and_epi32(t, first_byte); sum_abs = _mm_add_epi32(sum_abs, t); - sum_num = _mm_add_epi32(sum_num, _mm_max_epi32(t, ones)); + sum_num = _mm_add_epi32(sum_num, _mm_min_epi32(t, ones)); __m128i min_t = _mm_min_epi32( t, _mm_add_epi32( @@ -1399,7 +1392,7 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos, __m128i t = _mm_i32gather_epi32(absLevels, offset, 1); t = _mm_and_epi32(t, first_byte); sum_abs = _mm_add_epi32(sum_abs, t); - sum_num = _mm_add_epi32(sum_num, _mm_max_epi32(t, ones)); + sum_num = _mm_add_epi32(sum_num, _mm_min_epi32(t, ones)); __m128i min_t = _mm_min_epi32( t, _mm_add_epi32( @@ -1414,16 +1407,42 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos, sum_abs_1 = _mm_slli_epi32(sum_abs_1, 3); sum_abs = _mm_slli_epi32(_mm_min_epi32(_mm_set1_epi32(127), sum_abs), 8); __m128i template_ctx_init = _mm_add_epi32(sum_num, sum_abs); - _mm_add_epi32(template_ctx_init, sum_abs_1); + template_ctx_init = _mm_add_epi32(template_ctx_init, sum_abs_1); __m128i shuffle_mask = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 0, 0, 0, 0, 0, 0, 0, 0); __m128i shuffled_template_ctx_init = _mm_shuffle_epi8(template_ctx_init, shuffle_mask); temp[id % 4] = _mm_extract_epi64(shuffled_template_ctx_init, 0); if (id %4 == 3) { - all[0] = _mm256_loadu_epi64(temp); + all[id / 4] = _mm256_loadu_epi64(temp); + all[id / 4] = _mm256_shuffle_epi8(all[id / 4], v_shuffle); last = template_ctx_init; } } + __m256i* v_src_tmp = all; + + __m256i v_tmp[4]; + v_tmp[0] = _mm256_permute2x128_si256(v_src_tmp[0], v_src_tmp[1], 0x20); + v_tmp[1] = _mm256_permute2x128_si256(v_src_tmp[0], v_src_tmp[1], 0x31); + v_tmp[2] = _mm256_permute2x128_si256(v_src_tmp[2], v_src_tmp[3], 0x20); + v_tmp[3] = _mm256_permute2x128_si256(v_src_tmp[2], v_src_tmp[3], 0x31); + + __m256i v_tmp16_lo[2]; + __m256i v_tmp16_hi[2]; + v_tmp16_lo[0] = _mm256_unpacklo_epi32(v_tmp[0], v_tmp[1]); + v_tmp16_lo[1] = _mm256_unpacklo_epi32(v_tmp[2], v_tmp[3]); + v_tmp16_hi[0] = _mm256_unpackhi_epi32(v_tmp[0], v_tmp[1]); + v_tmp16_hi[1] = _mm256_unpackhi_epi32(v_tmp[2], v_tmp[3]); + + v_tmp[0] = _mm256_permute4x64_epi64(v_tmp16_lo[0], _MM_SHUFFLE(3, 1, 2, 0)); + v_tmp[1] = _mm256_permute4x64_epi64(v_tmp16_lo[1], _MM_SHUFFLE(3, 1, 2, 0)); + v_tmp[2] = _mm256_permute4x64_epi64(v_tmp16_hi[0], _MM_SHUFFLE(3, 1, 2, 0)); + v_tmp[3] = _mm256_permute4x64_epi64(v_tmp16_hi[1], _MM_SHUFFLE(3, 1, 2, 0)); + + _mm256_storeu_epi16(state->m_absLevelsAndCtxInit[state_offset] + 8, _mm256_permute2x128_si256(v_tmp[0], v_tmp[1], 0x20)); + _mm256_storeu_epi16(state->m_absLevelsAndCtxInit[state_offset + 1] + 8, _mm256_permute2x128_si256(v_tmp[0], v_tmp[1], 0x31)); + _mm256_storeu_epi16(state->m_absLevelsAndCtxInit[state_offset + 2] + 8, _mm256_permute2x128_si256(v_tmp[2], v_tmp[3], 0x20)); + _mm256_storeu_epi16(state->m_absLevelsAndCtxInit[state_offset + 3] + 8, _mm256_permute2x128_si256(v_tmp[2], v_tmp[3], 0x31)); + for (int i = 0; i < 4; ++i) { memset(state->m_absLevelsAndCtxInit[state_offset + i], 0, 16); } @@ -1442,6 +1461,7 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos, __m128i offsets = _mm_set_epi32(12 * 3, 12 * 2, 12 * 1, 12 * 0); offsets = _mm_add_epi32(offsets, _mm_set1_epi32(sigCtxOffsetNext)); + offsets = _mm_add_epi32(offsets, sum_abs_min); __m256i sig_frac_bits = _mm256_i32gather_epi64(state->m_sigFracBitsArray[state_offset][0], offsets, 8); _mm256_storeu_epi64(&state->m_sigFracBits[state_offset][0], sig_frac_bits); @@ -1451,7 +1471,7 @@ static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos, uint32_t sum_gt1_s[4]; _mm_storeu_epi32(sum_gt1_s, min_gt1); for (int i = 0; i < 4; ++i) { - memcpy(state->m_coeffFracBits[state_offset + i], state->m_gtxFracBitsArray[sum_gt1_s[i]], sizeof(state->m_coeffFracBits[0])); + memcpy(state->m_coeffFracBits[state_offset + i], state->m_gtxFracBitsArray[sum_gt1_s[i] + gtxCtxOffsetNext], sizeof(state->m_coeffFracBits[0])); } } else { @@ -2217,14 +2237,14 @@ static void xDecideAndUpdate( if (scan_pos) { if (!(scan_pos & 15)) { SWAP(ctxs->m_common_context.m_curr_sbb_ctx_offset, ctxs->m_common_context.m_prev_sbb_ctx_offset, int); - updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 0); - updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 1); - updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 2); - updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 3); + update_state_eos_avx2(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions); + //updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 0); + //updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 1); + //updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 2); + //updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 3); memcpy(decisions->prevId + 4, decisions->prevId, 4 * sizeof(int32_t)); memcpy(decisions->absLevel + 4, decisions->absLevel, 4 * sizeof(int32_t)); memcpy(decisions->rdCost + 4, decisions->rdCost, 4 * sizeof(int64_t)); - printf("\n"); } else if (!zeroOut) { update_states_avx2(ctxs, next_nb_info_ssb.num, scan_pos, decisions, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], next_nb_info_ssb, 4, false); /* updateState(ctxs, next_nb_info_ssb.num, scan_pos, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false, 0);