From c6e6f5da339a1ad164d8d391446bbdad724bb1a8 Mon Sep 17 00:00:00 2001 From: Joose Sainio Date: Sat, 8 Apr 2023 18:58:40 +0300 Subject: [PATCH] [avx2] WIP check_rd_costs_avx2, almost? --- src/dep_quant.c | 113 +++++++++++++++++++++++------------------------- 1 file changed, 55 insertions(+), 58 deletions(-) diff --git a/src/dep_quant.c b/src/dep_quant.c index 2ff848b1..f272ad6e 100644 --- a/src/dep_quant.c +++ b/src/dep_quant.c @@ -563,10 +563,6 @@ static INLINE void checkRdCostSkipSbbZeroOut( static void check_rd_costs_avx2(const all_depquant_states* const state, const enum ScanPosType spt, const PQData* pqDataA, Decision* decisions, int start) { - int32_t a[64] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64}; - __m128i offsets = _mm_set_epi32(12, 8, 4, 0); - __m128i r = _mm_i32gather_epi32(a, offsets, 1); - int64_t temp_rd_cost_a[4] = {0, 0, 0, 0}; int64_t temp_rd_cost_b[4] = {0, 0, 0, 0}; int64_t temp_rd_cost_z[4] = {0, 0, 0, 0}; @@ -600,15 +596,15 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en __m128i offsets = _mm_set_epi32(18 + pqDataA->absLevel[3], 12 + pqDataA->absLevel[3], 6 + pqDataA->absLevel[0], 0 + pqDataA->absLevel[0]); __m128i t = _mm_slli_epi32(value, 1); offsets = _mm_sub_epi32(offsets, t); - __m128i coeff_frac_bits = _mm_i32gather_epi32(state->m_coeffFracBits[start], offsets, 1); + __m128i coeff_frac_bits = _mm_i32gather_epi32(state->m_coeffFracBits[start], offsets, 4); - __m128i max_rice = _mm_set1_epi32(15); + __m128i max_rice = _mm_set1_epi32(31); value = _mm_min_epi32(value, max_rice); __m128i go_rice_tab = _mm_cvtepi8_epi32(_mm_loadu_si32(&state->m_goRicePar[start])); go_rice_tab = _mm_slli_epi32(value, 5); value = _mm_add_epi32(value, go_rice_tab); - __m128i temp = _mm_add_epi32(coeff_frac_bits, _mm_i32gather_epi32(&g_goRiceBits[0][0], value, 1)); + __m128i temp = _mm_add_epi32(coeff_frac_bits, _mm_i32gather_epi32(&g_goRiceBits[0][0], value, 4)); rd_cost_a = _mm256_add_epi64(rd_cost_a, _mm256_cvtepi32_epi64(temp)); } else { const int pqAs[4] = {0, 0, 3, 3}; @@ -629,7 +625,7 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en if (pqDataA->absLevel[1] < 4 && pqDataA->absLevel[2] < 4) { __m128i offsets = _mm_set_epi32(18 + pqDataA->absLevel[1], 12 + pqDataA->absLevel[1], 6 + pqDataA->absLevel[2], 0 + pqDataA->absLevel[2]); - __m128i coeff_frac_bits = _mm_i32gather_epi32(state->m_coeffFracBits[start], offsets, 1); + __m128i coeff_frac_bits = _mm_i32gather_epi32(state->m_coeffFracBits[start], offsets, 4); __m256i ext_frac_bits = _mm256_cvtepi32_epi64(coeff_frac_bits); rd_cost_b = _mm256_add_epi64(rd_cost_b, ext_frac_bits); } else if (pqDataA->absLevel[1] >= 4 && pqDataA->absLevel[2] >= 4) { @@ -638,28 +634,28 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en __m128i offsets = _mm_set_epi32(18 + pqDataA->absLevel[1], 12 + pqDataA->absLevel[1], 6 + pqDataA->absLevel[2], 0 + pqDataA->absLevel[2]); __m128i t = _mm_slli_epi32(value, 1); offsets = _mm_sub_epi32(offsets, t); - __m128i coeff_frac_bits = _mm_i32gather_epi32(state->m_coeffFracBits[start], offsets, 1); + __m128i coeff_frac_bits = _mm_i32gather_epi32(state->m_coeffFracBits[start], offsets, 4); - __m128i max_rice = _mm_set1_epi32(15); + __m128i max_rice = _mm_set1_epi32(31); value = _mm_min_epi32(value, max_rice); __m128i go_rice_tab = _mm_cvtepi8_epi32(_mm_loadu_si32(&state->m_goRicePar[start])); go_rice_tab = _mm_slli_epi32(go_rice_tab, 5); value = _mm_add_epi32(value, go_rice_tab); - __m128i temp = _mm_add_epi32(coeff_frac_bits, _mm_i32gather_epi32(&g_goRiceBits[0][0], value, 1)); + __m128i temp = _mm_add_epi32(coeff_frac_bits, _mm_i32gather_epi32(&g_goRiceBits[0][0], value, 4)); rd_cost_b = _mm256_add_epi64(rd_cost_b, _mm256_cvtepi32_epi64(temp)); } else { - const int pqAs[4] = {0, 0, 3, 3}; + const int pqBs[4] = {2, 2, 1, 1}; int64_t rd_costs[4] = {0, 0, 0, 0}; for (int i = 0; i < 4; i++) { const int state_offset = start + i; - const int pqA = pqAs[i]; + const int pqB = pqBs[i]; const int32_t* goRiceTab = g_goRiceBits[state->m_goRicePar[state_offset]]; - if (pqDataA->absLevel[pqA] < 4) { - rd_costs[i] = state->m_coeffFracBits[state_offset][pqDataA->absLevel[pqA]]; + if (pqDataA->absLevel[pqB] < 4) { + rd_costs[i] = state->m_coeffFracBits[state_offset][pqDataA->absLevel[pqB]]; } else { - const coeff_t value = (pqDataA->absLevel[pqA] - 4) >> 1; - rd_costs[i] += state->m_coeffFracBits[state_offset][pqDataA->absLevel[pqA] - (value << 1)] + goRiceTab[value < RICEMAX ? value : RICEMAX - 1]; + const coeff_t value = (pqDataA->absLevel[pqB] - 4) >> 1; + rd_costs[i] += state->m_coeffFracBits[state_offset][pqDataA->absLevel[pqB] - (value << 1)] + goRiceTab[value < RICEMAX ? value : RICEMAX - 1]; } } rd_cost_b = _mm256_add_epi64(rd_cost_b, _mm256_loadu_si256(&rd_costs[0])); @@ -672,7 +668,7 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en __m256i even = _mm256_permutevar8x32_epi32(original, even_mask); __m256i odd = _mm256_permutevar8x32_epi32(original, odd_mask); __m256i even_64 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(even, 0)); - __m256i odd_64 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(odd, 1)); + __m256i odd_64 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(odd, 0)); rd_cost_a = _mm256_add_epi64(rd_cost_a, odd_64); rd_cost_b = _mm256_add_epi64(rd_cost_b, odd_64); rd_cost_z = _mm256_add_epi64(rd_cost_z, even_64); @@ -683,11 +679,11 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en __m256i even = _mm256_permutevar8x32_epi32(original, even_mask); __m256i odd = _mm256_permutevar8x32_epi32(original, odd_mask); __m256i m_sigFracBits_0 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(even, 0)); - __m256i m_sigFracBits_1 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(odd, 1)); + __m256i m_sigFracBits_1 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(odd, 0)); original = _mm256_loadu_si256((__m256i const*)state->m_sbbFracBits[start]); odd = _mm256_permutevar8x32_epi32(original, odd_mask); - __m256i m_sbbFracBits_1 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(odd, 1)); + __m256i m_sbbFracBits_1 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(odd, 0)); rd_cost_a = _mm256_add_epi64(rd_cost_a, m_sbbFracBits_1); @@ -706,22 +702,26 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en __m256i even = _mm256_permutevar8x32_epi32(original, even_mask); __m256i odd = _mm256_permutevar8x32_epi32(original, odd_mask); __m256i even_64 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(even, 0)); - __m256i odd_64 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(odd, 1)); + __m256i odd_64 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(odd, 0)); rd_cost_a = _mm256_add_epi64(rd_cost_a, odd_64); rd_cost_b = _mm256_add_epi64(rd_cost_b, odd_64); - rd_cost_z = _mm256_add_epi64(rd_cost_z, even_64); + rd_cost_z = _mm256_add_epi64(rd_cost_z, even_64); + _mm256_storeu_epi64(temp_rd_cost_a, rd_cost_a); + _mm256_storeu_epi64(temp_rd_cost_b, rd_cost_b); + _mm256_storeu_epi64(temp_rd_cost_z, rd_cost_z); } else if (!state->m_numSigSbb[start] && !state->m_numSigSbb[start + 1] && !state->m_numSigSbb[start + 2] && !state->m_numSigSbb[start + 3]) { - rd_cost_z = _mm256_setr_epi64x(decisions->rdCost[3], decisions->rdCost[3], decisions->rdCost[0], decisions->rdCost[0]); + rd_cost_z = _mm256_setr_epi64x(decisions->rdCost[0], decisions->rdCost[0], decisions->rdCost[3], decisions->rdCost[3]); + _mm256_storeu_epi64(temp_rd_cost_a, rd_cost_a); + _mm256_storeu_epi64(temp_rd_cost_b, rd_cost_b); + _mm256_storeu_epi64(temp_rd_cost_z, rd_cost_z); } else { const int pqAs[4] = {0, 0, 3, 3}; - int64_t temp_rd_cost_a[4] = {0, 0, 0, 0}; - int64_t temp_rd_cost_b[4] = {0, 0, 0, 0}; - int64_t temp_rd_cost_z[4] = {0, 0, 0, 0}; - int64_t z_out[4] = {0, 0, 0, 0}; - _mm256_storeu_epi64(z_out, rd_cost_z); + _mm256_storeu_epi64(temp_rd_cost_a, rd_cost_a); + _mm256_storeu_epi64(temp_rd_cost_b, rd_cost_b); + _mm256_storeu_epi64(temp_rd_cost_z, rd_cost_z); for (int i = 0; i < 4; i++) { const int state_offset = start + i; if (state->m_numSigSbb[state_offset]) { @@ -729,13 +729,9 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en temp_rd_cost_b[i] += state->m_sigFracBits[state_offset][1]; temp_rd_cost_z[i] += state->m_sigFracBits[state_offset][0]; } else { - z_out[i] = decisions->rdCost[pqAs[i]]; + temp_rd_cost_z[i] = decisions->rdCost[pqAs[i]]; } } - rd_cost_z = _mm256_loadu_epi64(z_out); - rd_cost_a = _mm256_add_epi64(rd_cost_a, _mm256_loadu_epi64(temp_rd_cost_a)); - rd_cost_b = _mm256_add_epi64(rd_cost_b, _mm256_loadu_epi64(temp_rd_cost_b)); - rd_cost_z = _mm256_add_epi64(rd_cost_z, _mm256_loadu_epi64(temp_rd_cost_z)); } } _mm256_storeu_epi64(temp_rd_cost_a, rd_cost_a); @@ -743,25 +739,25 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en _mm256_storeu_epi64(temp_rd_cost_z, rd_cost_z); } else if (all_under_four) { __m128i scale_bits = _mm_set1_epi32(1 << SCALE_BITS); - __m128i max_rice = _mm_set1_epi32(15); + __m128i max_rice = _mm_set1_epi32(31); __m128i go_rice_zero = _mm_cvtepi8_epi32(_mm_loadu_epi8(&state->m_goRiceZero[start])); // RD cost A { __m128i pq_abs_a = _mm_set_epi32(pqDataA->absLevel[3], pqDataA->absLevel[3], pqDataA->absLevel[0], pqDataA->absLevel[0]); - __m128i cmp = _mm_cmplt_epi32(go_rice_zero, pq_abs_a); + __m128i cmp = _mm_cmpgt_epi32(pq_abs_a, go_rice_zero); __m128i go_rice_smaller = _mm_min_epi32(pq_abs_a, max_rice); __m128i other = _mm_sub_epi32(pq_abs_a, _mm_set1_epi32(1)); - __m128i selected = _mm_blendv_epi8(go_rice_smaller, other, cmp); + __m128i selected = _mm_blendv_epi8(other, go_rice_smaller, cmp); __m128i go_rice_offset = _mm_cvtepi8_epi32(_mm_loadu_si32(&state->m_goRicePar[start])); go_rice_offset = _mm_slli_epi32(go_rice_offset, 5); __m128i offsets = _mm_add_epi32(selected, go_rice_offset); - __m128i go_rice_tab = _mm_i32gather_epi32(&g_goRiceBits[0][0], offsets, 1); + __m128i go_rice_tab = _mm_i32gather_epi32(&g_goRiceBits[0][0], offsets, 4); __m128i temp = _mm_add_epi32(go_rice_tab, scale_bits); rd_cost_a = _mm256_add_epi64(rd_cost_a, _mm256_cvtepi32_epi64(temp)); @@ -769,20 +765,20 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en // RD cost b { __m128i pq_abs_b = _mm_set_epi32(pqDataA->absLevel[1], pqDataA->absLevel[1], pqDataA->absLevel[2], pqDataA->absLevel[2]); - __m128i cmp = _mm_cmplt_epi32(go_rice_zero, pq_abs_b); + __m128i cmp = _mm_cmpgt_epi32(pq_abs_b, go_rice_zero); __m128i go_rice_smaller = _mm_min_epi32(pq_abs_b, max_rice); __m128i other = _mm_sub_epi32(pq_abs_b, _mm_set1_epi32(1)); - __m128i selected = _mm_blendv_epi8(go_rice_smaller, other, cmp); + __m128i selected = _mm_blendv_epi8(other, go_rice_smaller, cmp); __m128i go_rice_offset = _mm_cvtepi8_epi32(_mm_loadu_si32(&state->m_goRicePar[start])); go_rice_offset = _mm_slli_epi32(go_rice_offset, 5); __m128i offsets = _mm_add_epi32(selected, go_rice_offset); - __m128i go_rice_tab = _mm_i32gather_epi32(&g_goRiceBits[0][0], offsets, 1); + __m128i go_rice_tab = _mm_i32gather_epi32(&g_goRiceBits[0][0], offsets, 4); __m128i temp = _mm_add_epi32(go_rice_tab, scale_bits); rd_cost_b = _mm256_add_epi64(rd_cost_b, _mm256_cvtepi32_epi64(temp)); @@ -793,7 +789,8 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en go_rice_offset = _mm_slli_epi32(go_rice_offset, 5); go_rice_offset = _mm_add_epi32(go_rice_offset, go_rice_zero); - rd_cost_z = _mm256_add_epi64(rd_cost_z, _mm256_cvtepi32_epi64(go_rice_offset)); + __m128i go_rice_tab = _mm_i32gather_epi32(&g_goRiceBits[0][0], go_rice_offset, 4); + rd_cost_z = _mm256_add_epi64(rd_cost_z, _mm256_cvtepi32_epi64(go_rice_tab)); } _mm256_storeu_epi64(temp_rd_cost_a, rd_cost_a); _mm256_storeu_epi64(temp_rd_cost_b, rd_cost_b); @@ -868,7 +865,7 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en decisions->prevId[0] = state->m_stateId[start + 1]; } - // Decision 1 + // Decision 2 if (temp_rd_cost_a[1] < decisions->rdCost[2]) { decisions->rdCost[2] = temp_rd_cost_a[1]; decisions->absLevel[2] = pqDataA->absLevel[0]; @@ -885,35 +882,35 @@ static void check_rd_costs_avx2(const all_depquant_states* const state, const en decisions->prevId[2] = state->m_stateId[start]; } - // Decision 2 - if (temp_rd_cost_a[2] < decisions->rdCost[0]) { - decisions->rdCost[2] = temp_rd_cost_a[2]; - decisions->absLevel[2] = pqDataA->absLevel[3]; - decisions->prevId[2] = state->m_stateId[start + 2]; + // Decision 1 + if (temp_rd_cost_a[2] < decisions->rdCost[1]) { + decisions->rdCost[1] = temp_rd_cost_a[2]; + decisions->absLevel[1] = pqDataA->absLevel[3]; + decisions->prevId[1] = state->m_stateId[start + 2]; } - if (temp_rd_cost_z[2] < decisions->rdCost[0]) { - decisions->rdCost[2] = temp_rd_cost_z[2]; - decisions->absLevel[2] = 0; - decisions->prevId[2] = state->m_stateId[start + 2]; + if (temp_rd_cost_z[2] < decisions->rdCost[1]) { + decisions->rdCost[1] = temp_rd_cost_z[2]; + decisions->absLevel[1] = 0; + decisions->prevId[1] = state->m_stateId[start + 2]; } - if (temp_rd_cost_b[3] < decisions->rdCost[0]) { - decisions->rdCost[2] = temp_rd_cost_b[3]; - decisions->absLevel[2] = pqDataA->absLevel[1]; - decisions->prevId[2] = state->m_stateId[start + 3]; + if (temp_rd_cost_b[3] < decisions->rdCost[1]) { + decisions->rdCost[1] = temp_rd_cost_b[3]; + decisions->absLevel[1] = pqDataA->absLevel[1]; + decisions->prevId[1] = state->m_stateId[start + 3]; } // Decision 3 - if (temp_rd_cost_a[3] < decisions->rdCost[1]) { + if (temp_rd_cost_a[3] < decisions->rdCost[3]) { decisions->rdCost[3] = temp_rd_cost_a[3]; decisions->absLevel[3] = pqDataA->absLevel[3]; decisions->prevId[3] = state->m_stateId[start + 3]; } - if (temp_rd_cost_z[3] < decisions->rdCost[1]) { + if (temp_rd_cost_z[3] < decisions->rdCost[3]) { decisions->rdCost[3] = temp_rd_cost_z[3]; decisions->absLevel[3] = 0; decisions->prevId[3] = state->m_stateId[start + 3]; } - if (temp_rd_cost_b[2] < decisions->rdCost[1]) { + if (temp_rd_cost_b[2] < decisions->rdCost[3]) { decisions->rdCost[3] = temp_rd_cost_b[2]; decisions->absLevel[3] = pqDataA->absLevel[1]; decisions->prevId[3] = state->m_stateId[start + 2];