From 9e27b4056a42b6dbc18ab29f850d48c681450288 Mon Sep 17 00:00:00 2001
From: Joose Sainio
Date: Mon, 17 Apr 2023 13:52:42 +0300
Subject: [PATCH] [avx2] WIP update_state_eos_avx2

---
 src/dep_quant.c | 304 +++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 288 insertions(+), 16 deletions(-)

diff --git a/src/dep_quant.c b/src/dep_quant.c
index d01b9da6..6ea82fef 100644
--- a/src/dep_quant.c
+++ b/src/dep_quant.c
@@ -107,9 +107,11 @@ typedef struct
   const NbInfoOut* m_nbInfo;
   uint32_t m_sbbFlagBits[2][2];
   SbbCtx m_allSbbCtx[8];
-  SbbCtx* m_currSbbCtx;
-  SbbCtx* m_prevSbbCtx;
-  uint8_t m_memory[8 * (TR_MAX_WIDTH * TR_MAX_WIDTH + 1024)];
+  int m_curr_sbb_ctx_offset;
+  int m_prev_sbb_ctx_offset;
+  uint8_t sbb_memory[8 * 1024];
+  uint8_t level_memory[8* TR_MAX_WIDTH * TR_MAX_WIDTH];
+  int num_coeff;
 } common_context;
 
 
@@ -447,14 +449,15 @@ static void reset_common_context(common_context* ctx, const rate_estimator * rat
 {
   //memset(&ctx->m_nbInfo, 0, sizeof(ctx->m_nbInfo));
   memcpy(&ctx->m_sbbFlagBits, &rate_estimator->m_sigSbbFracBits, sizeof(rate_estimator->m_sigSbbFracBits));
-  const int chunkSize = numSbb + num_coeff;
-  uint8_t* nextMem = ctx->m_memory;
-  for (int k = 0; k < 8; k++, nextMem += chunkSize) {
-    ctx->m_allSbbCtx[k].sbbFlags = nextMem;
-    ctx->m_allSbbCtx[k].levels = nextMem + numSbb;
+  uint8_t* next_sbb_memory = ctx->sbb_memory;
+  uint8_t* next_level_memory = ctx->level_memory;
+  for (int k = 0; k < 8; k++, next_sbb_memory += numSbb, next_level_memory += num_coeff) {
+    ctx->m_allSbbCtx[k].sbbFlags = next_sbb_memory;
+    ctx->m_allSbbCtx[k].levels = next_level_memory;
   }
-  ctx->m_currSbbCtx = &ctx->m_allSbbCtx[0];
-  ctx->m_prevSbbCtx = &ctx->m_allSbbCtx[4];
+  ctx->m_curr_sbb_ctx_offset = 0;
+  ctx->m_prev_sbb_ctx_offset = 4;
+  ctx->num_coeff = num_coeff;
 }
 
 static void init_rate_esimator(rate_estimator * rate_estimator, const cabac_data_t * const ctx, color_t color)
@@ -1121,12 +1124,12 @@ static INLINE void update_common_context(
   const int curr_state)
 {
   const uint32_t numSbb = width_in_sbb * height_in_sbb;
-  uint8_t* sbbFlags = cc->m_currSbbCtx[curr_state & 3].sbbFlags;
-  uint8_t* levels = cc->m_currSbbCtx[curr_state & 3].levels;
+  uint8_t* sbbFlags = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset + (curr_state & 3)].sbbFlags;
+  uint8_t* levels = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset + (curr_state & 3)].levels;
   size_t setCpSize = cc->m_nbInfo[scan_pos - 1].maxDist * sizeof(uint8_t);
   if (prev_state != -1 && ctxs->m_allStates.m_refSbbCtxId[prev_state] >= 0) {
-    memcpy(sbbFlags, cc->m_prevSbbCtx[ctxs->m_allStates.m_refSbbCtxId[prev_state]].sbbFlags, numSbb * sizeof(uint8_t));
-    memcpy(levels + scan_pos, cc->m_prevSbbCtx[ctxs->m_allStates.m_refSbbCtxId[prev_state]].levels + scan_pos, setCpSize);
+    memcpy(sbbFlags, cc->m_allSbbCtx[cc->m_prev_sbb_ctx_offset + ctxs->m_allStates.m_refSbbCtxId[prev_state]].sbbFlags, numSbb * sizeof(uint8_t));
+    memcpy(levels + scan_pos, cc->m_allSbbCtx[cc->m_prev_sbb_ctx_offset + ctxs->m_allStates.m_refSbbCtxId[prev_state]].levels + scan_pos, setCpSize);
   }
   else {
     memset(sbbFlags, 0, numSbb * sizeof(uint8_t));
@@ -1181,6 +1184,275 @@ static INLINE void update_common_context(
   memset(ctxs->m_allStates.m_absLevelsAndCtxInit[curr_state], 0, 16 * sizeof(uint8_t));
 }
 
+static INLINE void updateStateEOS(
+  context_store* ctxs,
+  const uint32_t scan_pos,
+  const uint32_t cg_pos,
+  const uint32_t sigCtxOffsetNext,
+  const uint32_t gtxCtxOffsetNext,
+  const uint32_t width_in_sbb,
+  const uint32_t height_in_sbb,
+  const uint32_t next_sbb_right,
+  const uint32_t next_sbb_below,
+  const Decision* decisions,
+  int decision_id);
+
+/**
+ * \brief AVX2 counterpart of updateStateEOS() that updates all four states of
+ *        a decision set in one call.
+ *
+ * The vector path is taken when every decisions->prevId[0..3] is > -2;
+ * otherwise the four states are updated one by one with updateStateEOS().
+ *
+ * \param ctxs             context store (states and common context)
+ * \param scan_pos         current scan position
+ * \param cg_pos           index written in the per-state sbbFlags arrays
+ * \param sigCtxOffsetNext offset added to the m_sigFracBitsArray index
+ * \param gtxCtxOffsetNext offset added to the m_gtxFracBitsArray index
+ * \param width_in_sbb     width_in_sbb * height_in_sbb sbbFlags bytes are used
+ * \param height_in_sbb    see width_in_sbb
+ * \param next_sbb_right   sbbFlags index of the right neighbour, 0 for none
+ * \param next_sbb_below   sbbFlags index of the below neighbour, 0 for none
+ * \param decisions        prevId, absLevel and rdCost of the four decisions
+ *
+ * NOTE(review): work in progress. Nothing calls this function yet and the
+ * packed template contexts in all[] are not written back to the states.
+ */
+static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos, const uint32_t cg_pos,
+                                  const uint32_t sigCtxOffsetNext, const uint32_t gtxCtxOffsetNext,
+                                  const uint32_t width_in_sbb, const uint32_t height_in_sbb,
+                                  const uint32_t next_sbb_right, const uint32_t next_sbb_below,
+                                  const Decision* decisions)
+{
+  all_depquant_states* state = &ctxs->m_allStates;
+  bool all_above_minus_two = true;
+  bool all_between_zero_and_three = true;
+  bool all_above_four = true;
+
+  int state_offset = ctxs->m_curr_state_offset;
+  __m256i rd_cost = _mm256_loadu_si256((const __m256i*)decisions->rdCost);
+  _mm256_storeu_si256((__m256i*)&ctxs->m_allStates.m_rdCost[state_offset], rd_cost);
+  for (int i = 0; i < 4; ++i) {
+    all_above_minus_two &= decisions->prevId[i] > -2;
+    all_between_zero_and_three &= decisions->prevId[i] >= 0 && decisions->prevId[i] < 4;
+    all_above_four &= decisions->prevId[i] >= 4;
+  }
+  if (all_above_minus_two) {
+    bool all_have_previous_state = true;
+    __m128i prev_state;
+    const __m128i prev_id = _mm_loadu_si128((const __m128i*)decisions->prevId);
+    __m128i abs_level = _mm_loadu_si128((const __m128i*)decisions->absLevel);
+    if (all_above_four) {
+      prev_state = _mm_add_epi32(
+        _mm_set1_epi32(ctxs->m_skip_state_offset),
+        _mm_sub_epi32(prev_id, _mm_set1_epi32(4))
+      );
+      memset(&state->m_numSigSbb[state_offset], 0, 4);
+      for (int i = 0; i < 4; ++i) {
+        memset(state->m_absLevelsAndCtxInit[state_offset + i], 0, 16 * sizeof(uint8_t));
+      }
+    } else if (all_between_zero_and_three) {
+      // Regular predecessors sit at m_prev_state_offset + prevId, exactly as
+      // in the mixed branch below.
+      prev_state = _mm_add_epi32(_mm_set1_epi32(ctxs->m_prev_state_offset), prev_id);
+      int32_t prev_state_scalar[4];
+      _mm_storeu_si128((__m128i*)prev_state_scalar, prev_state);
+      for (int i = 0; i < 4; ++i) {
+        // m_numSigSbb holds one byte per state (see the 4-byte memset above),
+        // so update it per state instead of gathering 32-bit lanes.
+        state->m_numSigSbb[state_offset + i] = state->m_numSigSbb[prev_state_scalar[i]] + !!decisions->absLevel[i];
+        memcpy(state->m_absLevelsAndCtxInit[state_offset + i], state->m_absLevelsAndCtxInit[prev_state_scalar[i]], 16 * sizeof(uint8_t));
+      }
+    } else {
+      int prev_state_s[4] = {-1, -1, -1, -1};
+      for (int i = 0; i < 4; ++i) {
+        const int decision_id = i;
+        const int curr_state_offset = state_offset + i;
+        if (decisions->prevId[decision_id] >= 4) {
+          prev_state_s[i] = ctxs->m_skip_state_offset + (decisions->prevId[decision_id] - 4);
+          state->m_numSigSbb[curr_state_offset] = 0;
+          memset(state->m_absLevelsAndCtxInit[curr_state_offset], 0, 16 * sizeof(uint8_t));
+        } else if (decisions->prevId[decision_id] >= 0) {
+          prev_state_s[i] = ctxs->m_prev_state_offset + decisions->prevId[decision_id];
+          state->m_numSigSbb[curr_state_offset] = state->m_numSigSbb[prev_state_s[i]] + !!decisions->absLevel[decision_id];
+          memcpy(state->m_absLevelsAndCtxInit[curr_state_offset], state->m_absLevelsAndCtxInit[prev_state_s[i]], 16 * sizeof(uint8_t));
+        } else {
+          state->m_numSigSbb[curr_state_offset] = 1;
+          memset(state->m_absLevelsAndCtxInit[curr_state_offset], 0, 16 * sizeof(uint8_t));
+          all_have_previous_state = false;
+        }
+      }
+      prev_state = _mm_loadu_si128((const __m128i*)prev_state_s);
+    }
+    uint32_t level_offset = scan_pos & 15;
+    __m128i max_abs = _mm_min_epi32(abs_level, _mm_set1_epi32(32));
+    uint32_t max_abs_s[4];
+    _mm_storeu_si128((__m128i*)max_abs_s, max_abs);
+    for (int i = 0; i < 4; ++i) {
+      uint8_t* levels = (uint8_t*)state->m_absLevelsAndCtxInit[state_offset + i];
+      levels[level_offset] = max_abs_s[i];
+    }
+
+    // Update common context
+    // `last` keeps the template context of the latest id with (id % 4) == 3.
+    __m128i last = _mm_setzero_si128();
+    {
+      const uint32_t numSbb = width_in_sbb * height_in_sbb;
+      common_context* cc = &ctxs->m_common_context;
+      size_t setCpSize = cc->m_nbInfo[scan_pos - 1].maxDist * sizeof(uint8_t);
+      int previous_state_array[4];
+      _mm_storeu_si128((__m128i*)previous_state_array, prev_state);
+      for (int curr_state = 0; curr_state < 4; ++curr_state) {
+        uint8_t* sbbFlags = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset + (curr_state)].sbbFlags;
+        uint8_t* levels = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset + (curr_state)].levels;
+        const int p_state = previous_state_array[curr_state];
+        if (p_state != -1 && ctxs->m_allStates.m_refSbbCtxId[p_state] >= 0) {
+          const int prev_sbb = cc->m_prev_sbb_ctx_offset + ctxs->m_allStates.m_refSbbCtxId[p_state];
+          memcpy(sbbFlags, cc->m_allSbbCtx[prev_sbb].sbbFlags, numSbb * sizeof(uint8_t));
+          memcpy(levels + scan_pos, cc->m_allSbbCtx[prev_sbb].levels + scan_pos, setCpSize);
+        } else {
+          memset(sbbFlags, 0, numSbb * sizeof(uint8_t));
+          memset(levels + scan_pos, 0, setCpSize);
+        }
+        sbbFlags[cg_pos] = !!ctxs->m_allStates.m_numSigSbb[curr_state + state_offset];
+        memcpy(levels + scan_pos, ctxs->m_allStates.m_absLevelsAndCtxInit[curr_state + state_offset], 16 * sizeof(uint8_t));
+      }
+
+      // The sbbFlags buffers of consecutive contexts are laid out numSbb bytes
+      // apart by reset_common_context(), so one gather reads a flag per state.
+      const int* sbb_flags_base = (const int*)cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset].sbbFlags;
+      __m128i sbb_offsets = _mm_set_epi32(3 * numSbb, 2 * numSbb, 1 * numSbb, 0);
+      __m128i next_sbb_right_m = _mm_set1_epi32(next_sbb_right);
+      __m128i sbb_offsets_right = _mm_add_epi32(sbb_offsets, next_sbb_right_m);
+      __m128i sbb_right = next_sbb_right ? _mm_i32gather_epi32(sbb_flags_base, sbb_offsets_right, 1) : _mm_set1_epi32(0);
+
+      __m128i sbb_offsets_below = _mm_add_epi32(sbb_offsets, _mm_set1_epi32(next_sbb_below));
+      __m128i sbb_below = next_sbb_below ? _mm_i32gather_epi32(sbb_flags_base, sbb_offsets_below, 1) : _mm_set1_epi32(0);
+
+      // Each gathered lane holds four bytes; keep the flag byte and clamp it
+      // to 0/1 because m_sbbFlagBits only has two rows.
+      __m128i sig_sbb = _mm_or_si128(sbb_right, sbb_below);
+      sig_sbb = _mm_and_si128(sig_sbb, _mm_set1_epi32(0xff));
+      sig_sbb = _mm_min_epi32(sig_sbb, _mm_set1_epi32(1));
+      __m256i sbb_frac_bits = _mm256_i32gather_epi64((const long long*)cc->m_sbbFlagBits, sig_sbb, 8);
+      _mm256_storeu_si256((__m256i*)state->m_sbbFracBits[state_offset], sbb_frac_bits);
+
+      memset(&state->m_numSigSbb[state_offset], 0, 4);
+      memset(&state->m_goRicePar[state_offset], 0, 4);
+
+      uint8_t states[4] = {0, 1, 2, 3};
+      memcpy(&state->m_refSbbCtxId[state_offset], states, 4);
+      if (all_have_previous_state) {
+        __m128i rem_reg_bins = _mm_i32gather_epi32((const int*)state->m_remRegBins, prev_state, 4);
+        _mm_storeu_si128((__m128i*)&state->m_remRegBins[state_offset], rem_reg_bins);
+      } else {
+        const int temp = (state->effWidth * state->effHeight * 28) / 16;
+        for (int i = 0; i < 4; ++i) {
+          if (previous_state_array[i] != -1) {
+            state->m_remRegBins[i + state_offset] = state->m_remRegBins[previous_state_array[i]];
+          } else {
+            state->m_remRegBins[i + state_offset] = temp;
+          }
+        }
+      }
+
+      const int scanBeg = scan_pos - 16;
+      const NbInfoOut* nbOut = cc->m_nbInfo + scanBeg;
+      const uint8_t* absLevels = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset].levels + scanBeg;
+
+      __m128i levels_offsets = _mm_set_epi32(cc->num_coeff * 3, cc->num_coeff * 2, cc->num_coeff * 1, 0);
+      __m128i first_byte = _mm_set1_epi32(0xff);
+      __m128i ones = _mm_set1_epi32(1);
+      __m128i fours = _mm_set1_epi32(4);
+      // TODO: all[] is filled below (one entry per four ids) but not consumed yet.
+      __m256i all[4];
+      uint64_t temp[4];
+      for (int id = 0; id < 16; id++, nbOut++) {
+        if (nbOut->num == 0) {
+          temp[id % 4] = 0;
+          if (id % 4 == 3) {
+            all[id / 4] = _mm256_loadu_si256((const __m256i*)temp);
+            last = _mm_setzero_si128();
+          }
+          continue;
+        }
+        assert(nbOut->num <= 5);
+        __m128i sum_abs = _mm_set1_epi32(0);
+        __m128i sum_abs_1 = _mm_set1_epi32(0);
+        __m128i sum_num = _mm_set1_epi32(0);
+        // Accumulate the template sums over the nbOut->num neighbour positions.
+        for (int nb = 0; nb < nbOut->num; ++nb) {
+          __m128i offset = _mm_add_epi32(levels_offsets, _mm_set1_epi32(nbOut->outPos[nb]));
+          __m128i t = _mm_i32gather_epi32((const int*)absLevels, offset, 1);
+          t = _mm_and_si128(t, first_byte);
+          sum_abs = _mm_add_epi32(sum_abs, t);
+          // Count non-zero neighbours only; the count is decoded from 3 bits below.
+          sum_num = _mm_add_epi32(sum_num, _mm_min_epi32(t, ones));
+          __m128i min_t = _mm_min_epi32(t, _mm_add_epi32(fours, _mm_and_si128(t, ones)));
+          sum_abs_1 = _mm_add_epi32(sum_abs_1, min_t);
+        }
+        sum_abs_1 = _mm_slli_epi32(sum_abs_1, 3);
+        sum_abs = _mm_slli_epi32(_mm_min_epi32(_mm_set1_epi32(127), sum_abs), 8);
+        __m128i template_ctx_init = _mm_add_epi32(sum_num, sum_abs);
+        template_ctx_init = _mm_add_epi32(template_ctx_init, sum_abs_1);
+        __m128i shuffle_mask = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 0, 0, 0, 0, 0, 0, 0, 0);
+        __m128i shuffled_template_ctx_init = _mm_shuffle_epi8(template_ctx_init, shuffle_mask);
+        temp[id % 4] = _mm_extract_epi64(shuffled_template_ctx_init, 0);
+        if (id % 4 == 3) {
+          all[id / 4] = _mm256_loadu_si256((const __m256i*)temp);
+          last = template_ctx_init;
+        }
+      }
+
+      for (int i = 0; i < 4; ++i) {
+        memset(state->m_absLevelsAndCtxInit[state_offset + i], 0, 16);
+      }
+    }
+
+    __m128i sum_num = _mm_and_si128(last, _mm_set1_epi32(7));
+    __m128i sum_abs1 = _mm_and_si128(
+      _mm_srli_epi32(last, 3),
+      _mm_set1_epi32(31));
+
+    __m128i sum_abs_min = _mm_min_epi32(
+      _mm_set1_epi32(3),
+      _mm_srli_epi32(
+        _mm_add_epi32(sum_abs1, _mm_set1_epi32(1)),
+        1));
+
+    __m128i offsets = _mm_set_epi32(12 * 3, 12 * 2, 12 * 1, 12 * 0);
+    offsets = _mm_add_epi32(offsets, _mm_set1_epi32(sigCtxOffsetNext));
+    // Index is sigCtxOffsetNext + min(3, (sum_abs1 + 1) >> 1) in each state's 12 entries.
+    offsets = _mm_add_epi32(offsets, sum_abs_min);
+    __m256i sig_frac_bits = _mm256_i32gather_epi64((const long long*)state->m_sigFracBitsArray[state_offset][0], offsets, 8);
+    _mm256_storeu_si256((__m256i*)&state->m_sigFracBits[state_offset][0], sig_frac_bits);
+
+    __m128i sum_gt1 = _mm_sub_epi32(sum_abs1, sum_num);
+    __m128i min_gt1 = _mm_min_epi32(sum_gt1, _mm_set1_epi32(4));
+    uint32_t sum_gt1_s[4];
+    _mm_storeu_si128((__m128i*)sum_gt1_s, min_gt1);
+    for (int i = 0; i < 4; ++i) {
+      memcpy(state->m_coeffFracBits[state_offset + i], state->m_gtxFracBitsArray[gtxCtxOffsetNext + sum_gt1_s[i]], sizeof(state->m_coeffFracBits[0]));
+    }
+  }
+  else {
+    for (int i = 0; i < 4; i++) {
+      updateStateEOS(
+        ctxs,
+        scan_pos,
+        cg_pos,
+        sigCtxOffsetNext,
+        gtxCtxOffsetNext,
+        width_in_sbb,
+        height_in_sbb,
+        next_sbb_right,
+        next_sbb_below,
+        decisions,
+        i);
+    }
+  }
+}
+
 
 static INLINE void updateStateEOS(
   context_store* ctxs,
@@ -1215,7 +1487,7 @@ static INLINE void updateStateEOS(
     memset(state->m_absLevelsAndCtxInit[curr_state_offset], 0, 16 * sizeof(uint8_t));
   }
   uint8_t* temp = (uint8_t*)(&state->m_absLevelsAndCtxInit[curr_state_offset][scan_pos & 15]);
-  *temp = (uint8_t)MIN(255, decisions->absLevel[decision_id]);
+  *temp = (uint8_t)MIN(32, decisions->absLevel[decision_id]);
 
   update_common_context(ctxs, state->m_commonCtx, scan_pos, cg_pos, width_in_sbb, height_in_sbb, next_sbb_right,
                         next_sbb_below, prvState, ctxs->m_curr_state_offset + decision_id);
@@ -1925,7 +2197,7 @@ static void xDecideAndUpdate(
 
   if (scan_pos) {
     if (!(scan_pos & 15)) {
-      SWAP(ctxs->m_common_context.m_currSbbCtx, ctxs->m_common_context.m_prevSbbCtx, SbbCtx*);
+      SWAP(ctxs->m_common_context.m_curr_sbb_ctx_offset, ctxs->m_common_context.m_prev_sbb_ctx_offset, int);
       updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 0);
       updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 1);
       updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 2);