From aa48943c22530d1afffe4c67f043963b6c60f8f5 Mon Sep 17 00:00:00 2001 From: Joose Sainio Date: Wed, 12 Apr 2023 15:36:45 +0300 Subject: [PATCH] [avx2] Do decision cost comparison with avx2 --- src/dep_quant.c | 102 +++++++++++++++--------------------------------- 1 file changed, 32 insertions(+), 70 deletions(-) diff --git a/src/dep_quant.c b/src/dep_quant.c index 7de8828f..b6158f68 100644 --- a/src/dep_quant.c +++ b/src/dep_quant.c @@ -97,8 +97,8 @@ typedef struct typedef struct { int64_t rdCost[8]; - coeff_t absLevel[8]; - int prevId[8]; + int32_t absLevel[8]; + int32_t prevId[8]; } Decision; @@ -877,73 +877,36 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en rd_cost_b = _mm256_loadu_epi64(temp_rd_cost_b); rd_cost_z = _mm256_loadu_epi64(temp_rd_cost_z); } - // Decision 0 - if (temp_rd_cost_a[0] < decisions->rdCost[0]) { - decisions->rdCost[0] = temp_rd_cost_a[0]; - decisions->absLevel[0] = pqDataA->absLevel[0]; - decisions->prevId[0] = 0; - } - if (temp_rd_cost_z[0] < decisions->rdCost[0]) { - decisions->rdCost[0] = temp_rd_cost_z[0]; - decisions->absLevel[0] = 0; - decisions->prevId[0] = 0; - } - if (temp_rd_cost_b[1] < decisions->rdCost[0]) { - decisions->rdCost[0] = temp_rd_cost_b[1]; - decisions->absLevel[0] = pqDataA->absLevel[2]; - decisions->prevId[0] = 1; - } + rd_cost_a = _mm256_permute4x64_epi64(rd_cost_a, 216); + rd_cost_b = _mm256_permute4x64_epi64(rd_cost_b, 141); + rd_cost_z = _mm256_permute4x64_epi64(rd_cost_z, 216); + __m256i rd_cost_decision = _mm256_loadu_epi64(decisions->rdCost); - // Decision 2 - if (temp_rd_cost_a[1] < decisions->rdCost[2]) { - decisions->rdCost[2] = temp_rd_cost_a[1]; - decisions->absLevel[2] = pqDataA->absLevel[0]; - decisions->prevId[2] =1; - } - if (temp_rd_cost_z[1] < decisions->rdCost[2]) { - decisions->rdCost[2] = temp_rd_cost_z[1]; - decisions->absLevel[2] = 0; - decisions->prevId[2] = 1; - } - if (temp_rd_cost_b[0] < decisions->rdCost[2]) { - decisions->rdCost[2] = temp_rd_cost_b[0]; - decisions->absLevel[2] = pqDataA->absLevel[2]; - decisions->prevId[2] = 0; - } + __m256i decision_abs_coeff = _mm256_loadu_epi32(decisions->absLevel); + __m256i decision_prev_state = _mm256_loadu_epi32(decisions->prevId); + __m256i decision_data = _mm256_permute2x128_si256(decision_abs_coeff, decision_prev_state, 0x20); + __m256i mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); + decision_data = _mm256_permutevar8x32_epi32(decision_data, mask); - // Decision 1 - if (temp_rd_cost_a[2] < decisions->rdCost[1]) { - decisions->rdCost[1] = temp_rd_cost_a[2]; - decisions->absLevel[1] = pqDataA->absLevel[3]; - decisions->prevId[1] = 2; - } - if (temp_rd_cost_z[2] < decisions->rdCost[1]) { - decisions->rdCost[1] = temp_rd_cost_z[2]; - decisions->absLevel[1] = 0; - decisions->prevId[1] = 2; - } - if (temp_rd_cost_b[3] < decisions->rdCost[1]) { - decisions->rdCost[1] = temp_rd_cost_b[3]; - decisions->absLevel[1] = pqDataA->absLevel[1]; - decisions->prevId[1] = 3; - } + __m256i a_data = _mm256_set_epi32(3, pqDataA->absLevel[3], 1, pqDataA->absLevel[0], 2, pqDataA->absLevel[3], 0, pqDataA->absLevel[0]); + __m256i b_data = _mm256_set_epi32(2, pqDataA->absLevel[1], 0, pqDataA->absLevel[2], 3, pqDataA->absLevel[1], 1, pqDataA->absLevel[2]); + __m256i z_data = _mm256_set_epi32(3, 0, 1, 0, 2, 0, 0, 0); - // Decision 3 - if (temp_rd_cost_a[3] < decisions->rdCost[3]) { - decisions->rdCost[3] = temp_rd_cost_a[3]; - decisions->absLevel[3] = pqDataA->absLevel[3]; - decisions->prevId[3] = 3; - } - if (temp_rd_cost_z[3] < decisions->rdCost[3]) { - decisions->rdCost[3] = temp_rd_cost_z[3]; - decisions->absLevel[3] = 0; - decisions->prevId[3] = 3; - } - if (temp_rd_cost_b[2] < decisions->rdCost[3]) { - decisions->rdCost[3] = temp_rd_cost_b[2]; - decisions->absLevel[3] = pqDataA->absLevel[1]; - decisions->prevId[3] = 2; - } + __m256i a_vs_b = _mm256_cmpgt_epi64(rd_cost_a, rd_cost_b); + __m256i cheaper_first = _mm256_blendv_epi8(rd_cost_a, rd_cost_b, a_vs_b); + __m256i cheaper_first_data = _mm256_blendv_epi8(a_data, b_data, a_vs_b); + + __m256i z_vs_decision = _mm256_cmpgt_epi64(rd_cost_z, rd_cost_decision); + __m256i cheaper_second = _mm256_blendv_epi8(rd_cost_z, rd_cost_decision, z_vs_decision); + __m256i cheaper_second_data = _mm256_blendv_epi8(z_data, decision_data, z_vs_decision); + + __m256i final_decision = _mm256_cmpgt_epi64(cheaper_first, cheaper_second); + __m256i final_rd_cost = _mm256_blendv_epi8(cheaper_first, cheaper_second, final_decision); + __m256i final_data = _mm256_blendv_epi8(cheaper_first_data, cheaper_second_data, final_decision); + + _mm256_storeu_epi64(decisions->rdCost, final_rd_cost); + final_data = _mm256_permutevar8x32_epi32(final_data, _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0)); + _mm256_storeu2_m128i(decisions->prevId, decisions->absLevel, final_data); } @@ -1310,8 +1273,7 @@ static INLINE void update_states_avx2( bool rem_reg_all_gte_4 = true; bool rem_reg_all_lt4 = true; - __m128i abs_level = _mm_loadu_epi16(decisions->absLevel); - abs_level = _mm_cvtepi16_epi32(abs_level); + __m128i abs_level = _mm_loadu_epi32(decisions->absLevel); if (all_non_negative) { __m128i prv_states = _mm_loadu_epi32(decisions->prevId); __m128i prev_offset = _mm_set1_epi32(ctxs->m_prev_state_offset); @@ -1972,8 +1934,8 @@ static void xDecideAndUpdate( updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 1); updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 2); updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 3); - memcpy(decisions->prevId + 4, decisions->prevId, 4 * sizeof(int)); - memcpy(decisions->absLevel + 4, decisions->absLevel, 4 * sizeof(coeff_t)); + memcpy(decisions->prevId + 4, decisions->prevId, 4 * sizeof(int32_t)); + memcpy(decisions->absLevel + 4, decisions->absLevel, 4 * sizeof(int32_t)); memcpy(decisions->rdCost + 4, decisions->rdCost, 4 * sizeof(int64_t)); } else if (!zeroOut) { update_states_avx2(ctxs, next_nb_info_ssb.num, scan_pos, decisions, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], next_nb_info_ssb, 4, false);