[avx2] Do decision cost comparison with avx2

This commit is contained in:
Joose Sainio 2023-04-13 15:20:49 +03:00
parent fcbd12fef3
commit 6e2eaf9d6b

View file

@ -97,8 +97,8 @@ typedef struct
typedef struct typedef struct
{ {
int64_t rdCost[8]; int64_t rdCost[8];
coeff_t absLevel[8]; int32_t absLevel[8];
int prevId[8]; int32_t prevId[8];
} Decision; } Decision;
@ -880,76 +880,33 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en
rd_cost_a = _mm256_permute4x64_epi64(rd_cost_a, 216); rd_cost_a = _mm256_permute4x64_epi64(rd_cost_a, 216);
rd_cost_b = _mm256_permute4x64_epi64(rd_cost_b, 141); rd_cost_b = _mm256_permute4x64_epi64(rd_cost_b, 141);
rd_cost_z = _mm256_permute4x64_epi64(rd_cost_z, 216); rd_cost_z = _mm256_permute4x64_epi64(rd_cost_z, 216);
__m256i rd_cost_decision = _mm256_loadu_epi64(decisions->rdCost);
__m256i decision_data; __m256i decision_abs_coeff = _mm256_loadu_epi32(decisions->absLevel);
__m256i decision_prev_state = _mm256_loadu_epi32(decisions->prevId);
__m256i decision_data = _mm256_permute2x128_si256(decision_abs_coeff, decision_prev_state, 0x20);
__m256i mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
decision_data = _mm256_permutevar8x32_epi32(decision_data, mask);
// Decision 0 __m256i a_data = _mm256_set_epi32(3, pqDataA->absLevel[3], 1, pqDataA->absLevel[0], 2, pqDataA->absLevel[3], 0, pqDataA->absLevel[0]);
if (temp_rd_cost_a[0] < decisions->rdCost[0]) { __m256i b_data = _mm256_set_epi32(2, pqDataA->absLevel[1], 0, pqDataA->absLevel[2], 3, pqDataA->absLevel[1], 1, pqDataA->absLevel[2]);
decisions->rdCost[0] = temp_rd_cost_a[0]; __m256i z_data = _mm256_set_epi32(3, 0, 1, 0, 2, 0, 0, 0);
decisions->absLevel[0] = pqDataA->absLevel[0];
decisions->prevId[0] = 0;
}
if (temp_rd_cost_z[0] < decisions->rdCost[0]) {
decisions->rdCost[0] = temp_rd_cost_z[0];
decisions->absLevel[0] = 0;
decisions->prevId[0] = 0;
}
if (temp_rd_cost_b[1] < decisions->rdCost[0]) {
decisions->rdCost[0] = temp_rd_cost_b[1];
decisions->absLevel[0] = pqDataA->absLevel[2];
decisions->prevId[0] = 1;
}
// Decision 2 __m256i a_vs_b = _mm256_cmpgt_epi64(rd_cost_a, rd_cost_b);
if (temp_rd_cost_a[1] < decisions->rdCost[2]) { __m256i cheaper_first = _mm256_blendv_epi8(rd_cost_a, rd_cost_b, a_vs_b);
decisions->rdCost[2] = temp_rd_cost_a[1]; __m256i cheaper_first_data = _mm256_blendv_epi8(a_data, b_data, a_vs_b);
decisions->absLevel[2] = pqDataA->absLevel[0];
decisions->prevId[2] =1;
}
if (temp_rd_cost_z[1] < decisions->rdCost[2]) {
decisions->rdCost[2] = temp_rd_cost_z[1];
decisions->absLevel[2] = 0;
decisions->prevId[2] = 1;
}
if (temp_rd_cost_b[0] < decisions->rdCost[2]) {
decisions->rdCost[2] = temp_rd_cost_b[0];
decisions->absLevel[2] = pqDataA->absLevel[2];
decisions->prevId[2] = 0;
}
// Decision 1 __m256i z_vs_decision = _mm256_cmpgt_epi64(rd_cost_z, rd_cost_decision);
if (temp_rd_cost_a[2] < decisions->rdCost[1]) { __m256i cheaper_second = _mm256_blendv_epi8(rd_cost_z, rd_cost_decision, z_vs_decision);
decisions->rdCost[1] = temp_rd_cost_a[2]; __m256i cheaper_second_data = _mm256_blendv_epi8(z_data, decision_data, z_vs_decision);
decisions->absLevel[1] = pqDataA->absLevel[3];
decisions->prevId[1] = 2;
}
if (temp_rd_cost_z[2] < decisions->rdCost[1]) {
decisions->rdCost[1] = temp_rd_cost_z[2];
decisions->absLevel[1] = 0;
decisions->prevId[1] = 2;
}
if (temp_rd_cost_b[3] < decisions->rdCost[1]) {
decisions->rdCost[1] = temp_rd_cost_b[3];
decisions->absLevel[1] = pqDataA->absLevel[1];
decisions->prevId[1] = 3;
}
// Decision 3 __m256i final_decision = _mm256_cmpgt_epi64(cheaper_first, cheaper_second);
if (temp_rd_cost_a[3] < decisions->rdCost[3]) { __m256i final_rd_cost = _mm256_blendv_epi8(cheaper_first, cheaper_second, final_decision);
decisions->rdCost[3] = temp_rd_cost_a[3]; __m256i final_data = _mm256_blendv_epi8(cheaper_first_data, cheaper_second_data, final_decision);
decisions->absLevel[3] = pqDataA->absLevel[3];
decisions->prevId[3] = 3; _mm256_storeu_epi64(decisions->rdCost, final_rd_cost);
} final_data = _mm256_permutevar8x32_epi32(final_data, _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0));
if (temp_rd_cost_z[3] < decisions->rdCost[3]) { _mm256_storeu2_m128i(decisions->prevId, decisions->absLevel, final_data);
decisions->rdCost[3] = temp_rd_cost_z[3];
decisions->absLevel[3] = 0;
decisions->prevId[3] = 3;
}
if (temp_rd_cost_b[2] < decisions->rdCost[3]) {
decisions->rdCost[3] = temp_rd_cost_b[2];
decisions->absLevel[3] = pqDataA->absLevel[1];
decisions->prevId[3] = 2;
}
} }
@ -1316,8 +1273,7 @@ static INLINE void update_states_avx2(
bool rem_reg_all_gte_4 = true; bool rem_reg_all_gte_4 = true;
bool rem_reg_all_lt4 = true; bool rem_reg_all_lt4 = true;
__m128i abs_level = _mm_loadu_epi16(decisions->absLevel); __m128i abs_level = _mm_loadu_epi32(decisions->absLevel);
abs_level = _mm_cvtepi16_epi32(abs_level);
if (all_non_negative) { if (all_non_negative) {
__m128i prv_states = _mm_loadu_epi32(decisions->prevId); __m128i prv_states = _mm_loadu_epi32(decisions->prevId);
__m128i prev_offset = _mm_set1_epi32(ctxs->m_prev_state_offset); __m128i prev_offset = _mm_set1_epi32(ctxs->m_prev_state_offset);
@ -1978,8 +1934,8 @@ static void xDecideAndUpdate(
updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 1); updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 1);
updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 2); updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 2);
updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 3); updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 3);
memcpy(decisions->prevId + 4, decisions->prevId, 4 * sizeof(int)); memcpy(decisions->prevId + 4, decisions->prevId, 4 * sizeof(int32_t));
memcpy(decisions->absLevel + 4, decisions->absLevel, 4 * sizeof(coeff_t)); memcpy(decisions->absLevel + 4, decisions->absLevel, 4 * sizeof(int32_t));
memcpy(decisions->rdCost + 4, decisions->rdCost, 4 * sizeof(int64_t)); memcpy(decisions->rdCost + 4, decisions->rdCost, 4 * sizeof(int64_t));
} else if (!zeroOut) { } else if (!zeroOut) {
update_states_avx2(ctxs, next_nb_info_ssb.num, scan_pos, decisions, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], next_nb_info_ssb, 4, false); update_states_avx2(ctxs, next_nb_info_ssb.num, scan_pos, decisions, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], next_nb_info_ssb, 4, false);