[avx2] WIP update_state_eos_avx2

This commit is contained in:
Joose Sainio 2023-04-17 13:52:42 +03:00
parent c56350b8d6
commit 9e27b4056a

View file

@ -107,9 +107,11 @@ typedef struct
const NbInfoOut* m_nbInfo;
uint32_t m_sbbFlagBits[2][2];
SbbCtx m_allSbbCtx[8];
SbbCtx* m_currSbbCtx;
SbbCtx* m_prevSbbCtx;
uint8_t m_memory[8 * (TR_MAX_WIDTH * TR_MAX_WIDTH + 1024)];
int m_curr_sbb_ctx_offset;
int m_prev_sbb_ctx_offset;
uint8_t sbb_memory[8 * 1024];
uint8_t level_memory[8* TR_MAX_WIDTH * TR_MAX_WIDTH];
int num_coeff;
} common_context;
@ -447,14 +449,15 @@ static void reset_common_context(common_context* ctx, const rate_estimator * rat
{
//memset(&ctx->m_nbInfo, 0, sizeof(ctx->m_nbInfo));
memcpy(&ctx->m_sbbFlagBits, &rate_estimator->m_sigSbbFracBits, sizeof(rate_estimator->m_sigSbbFracBits));
const int chunkSize = numSbb + num_coeff;
uint8_t* nextMem = ctx->m_memory;
for (int k = 0; k < 8; k++, nextMem += chunkSize) {
ctx->m_allSbbCtx[k].sbbFlags = nextMem;
ctx->m_allSbbCtx[k].levels = nextMem + numSbb;
uint8_t* next_sbb_memory = ctx->sbb_memory;
uint8_t* next_level_memory = ctx->level_memory;
for (int k = 0; k < 8; k++, next_sbb_memory += numSbb, next_level_memory += num_coeff) {
ctx->m_allSbbCtx[k].sbbFlags = next_sbb_memory;
ctx->m_allSbbCtx[k].levels = next_level_memory;
}
ctx->m_currSbbCtx = &ctx->m_allSbbCtx[0];
ctx->m_prevSbbCtx = &ctx->m_allSbbCtx[4];
ctx->m_curr_sbb_ctx_offset = 0;
ctx->m_prev_sbb_ctx_offset = 4;
ctx->num_coeff = num_coeff;
}
static void init_rate_esimator(rate_estimator * rate_estimator, const cabac_data_t * const ctx, color_t color)
@ -1121,12 +1124,12 @@ static INLINE void update_common_context(
const int curr_state)
{
const uint32_t numSbb = width_in_sbb * height_in_sbb;
uint8_t* sbbFlags = cc->m_currSbbCtx[curr_state & 3].sbbFlags;
uint8_t* levels = cc->m_currSbbCtx[curr_state & 3].levels;
uint8_t* sbbFlags = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset + (curr_state & 3)].sbbFlags;
uint8_t* levels = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset + (curr_state & 3)].levels;
size_t setCpSize = cc->m_nbInfo[scan_pos - 1].maxDist * sizeof(uint8_t);
if (prev_state != -1 && ctxs->m_allStates.m_refSbbCtxId[prev_state] >= 0) {
memcpy(sbbFlags, cc->m_prevSbbCtx[ctxs->m_allStates.m_refSbbCtxId[prev_state]].sbbFlags, numSbb * sizeof(uint8_t));
memcpy(levels + scan_pos, cc->m_prevSbbCtx[ctxs->m_allStates.m_refSbbCtxId[prev_state]].levels + scan_pos, setCpSize);
memcpy(sbbFlags, cc->m_allSbbCtx[cc->m_prev_sbb_ctx_offset + ctxs->m_allStates.m_refSbbCtxId[prev_state]].sbbFlags, numSbb * sizeof(uint8_t));
memcpy(levels + scan_pos, cc->m_allSbbCtx[cc->m_prev_sbb_ctx_offset + ctxs->m_allStates.m_refSbbCtxId[prev_state]].levels + scan_pos, setCpSize);
}
else {
memset(sbbFlags, 0, numSbb * sizeof(uint8_t));
@ -1181,6 +1184,323 @@ static INLINE void update_common_context(
memset(ctxs->m_allStates.m_absLevelsAndCtxInit[curr_state], 0, 16 * sizeof(uint8_t));
}
// Forward declaration of the scalar, one-decision-at-a-time end-of-sub-block
// update. It is declared here because update_state_eos_avx2() calls it once
// per decision as a fallback, while its definition appears later in the file.
// decision_id is the index used to read decisions->prevId, decisions->absLevel
// and decisions->rdCost for the state being updated.
static INLINE void updateStateEOS(
context_store* ctxs,
const uint32_t scan_pos,
const uint32_t cg_pos,
const uint32_t sigCtxOffsetNext,
const uint32_t gtxCtxOffsetNext,
const uint32_t width_in_sbb,
const uint32_t height_in_sbb,
const uint32_t next_sbb_right,
const uint32_t next_sbb_below,
const Decision* decisions,
int decision_id);
/*
 * Update the four current states at the end of a coefficient sub-block,
 * processing decisions 0..3 together with SSE/AVX2 intrinsics.
 *
 * The four RD costs are always copied into the current states. If any
 * decision has prevId <= -2 the function falls back to calling the scalar
 * updateStateEOS() once per decision; otherwise it performs, for all four
 * states at once, the same steps as the scalar code visible in this file:
 *   1. resolve the previous state of every decision and initialise
 *      m_numSigSbb and the 16 level bytes of m_absLevelsAndCtxInit,
 *   2. store the clamped absolute level of this scan position,
 *   3. update the common context (sbb flags and level buffers), the sbb
 *      fractional bits, m_remRegBins, m_goRicePar and m_refSbbCtxId,
 *   4. build the packed template value of every position of the sub-block
 *      and select m_sigFracBits / m_coeffFracBits from the last one.
 *
 * ctxs              context store holding the states and the common context
 * scan_pos          scan position; only its low four bits select the level byte
 * cg_pos            index written in each state's sbbFlags buffer
 * sigCtxOffsetNext  base index added to the m_sigFracBitsArray lookup
 * gtxCtxOffsetNext  base index added to the m_gtxFracBitsArray lookup
 * width_in_sbb,
 * height_in_sbb     their product is the number of sbbFlags bytes copied/cleared
 * next_sbb_right,
 * next_sbb_below    indices into sbbFlags; a value of 0 means "no such neighbour"
 * decisions         prevId / absLevel / rdCost of the four decisions
 *
 * TODO(review): the packed template values collected in `all` are not yet
 * written back to the states; only the value of the last position is consumed.
 */
static void update_state_eos_avx2(context_store* ctxs, const uint32_t scan_pos, const uint32_t cg_pos,
                                  const uint32_t sigCtxOffsetNext, const uint32_t gtxCtxOffsetNext,
                                  const uint32_t width_in_sbb, const uint32_t height_in_sbb,
                                  const uint32_t next_sbb_right, const uint32_t next_sbb_below,
                                  const Decision* decisions)
{
  all_depquant_states* state = &ctxs->m_allStates;
  const int state_offset = ctxs->m_curr_state_offset;

  // The RD costs are taken over unconditionally. Plain AVX loads/stores are
  // used here: the *_epi64 / *_epi32 load/store forms require AVX-512VL.
  const __m256i rd_cost = _mm256_loadu_si256((const __m256i*)decisions->rdCost);
  _mm256_storeu_si256((__m256i*)&state->m_rdCost[state_offset], rd_cost);

  bool all_above_minus_two = true;
  bool all_between_zero_and_three = true;
  bool all_above_four = true;
  for (int i = 0; i < 4; ++i) {
    const int id = decisions->prevId[i];
    all_above_minus_two &= id > -2;
    all_between_zero_and_three &= id >= 0 && id < 4;
    all_above_four &= id >= 4;
  }

  if (!all_above_minus_two) {
    // At least one decision has no valid predecessor information: let the
    // scalar routine handle every decision individually.
    for (int i = 0; i < 4; i++) {
      updateStateEOS(ctxs, scan_pos, cg_pos, sigCtxOffsetNext, gtxCtxOffsetNext, width_in_sbb, height_in_sbb,
                     next_sbb_right, next_sbb_below, decisions, i);
    }
    return;
  }

  bool all_have_previous_state = true;
  const __m128i prev_ids = _mm_loadu_si128((const __m128i*)decisions->prevId);
  const __m128i abs_level = _mm_loadu_si128((const __m128i*)decisions->absLevel);
  __m128i prev_state;
  int32_t prev_state_s[4] = {-1, -1, -1, -1};

  if (all_above_four) {
    // Every decision continues from a skip state: prev = skip_offset + (prevId - 4).
    prev_state = _mm_add_epi32(
      _mm_set1_epi32(ctxs->m_skip_state_offset),
      _mm_sub_epi32(prev_ids, _mm_set1_epi32(4)));
    memset(&state->m_numSigSbb[state_offset], 0, 4);
    for (int i = 0; i < 4; ++i) {
      memset(state->m_absLevelsAndCtxInit[state_offset + i], 0, 16 * sizeof(uint8_t));
    }
  } else if (all_between_zero_and_three) {
    // Every decision continues from a regular previous state:
    // prev = prev_offset + prevId, exactly as in the mixed scalar branch below.
    prev_state = _mm_add_epi32(_mm_set1_epi32(ctxs->m_prev_state_offset), prev_ids);
    _mm_storeu_si128((__m128i*)prev_state_s, prev_state);
    for (int i = 0; i < 4; ++i) {
      const int curr = state_offset + i;
      // Four byte-sized reads; done in scalar so that no lane reads beyond
      // the element it needs.
      state->m_numSigSbb[curr] = state->m_numSigSbb[prev_state_s[i]] + !!decisions->absLevel[i];
      memcpy(state->m_absLevelsAndCtxInit[curr], state->m_absLevelsAndCtxInit[prev_state_s[i]], 16 * sizeof(uint8_t));
    }
  } else {
    // Mixed predecessors: resolve each decision on its own.
    for (int i = 0; i < 4; ++i) {
      const int curr = state_offset + i;
      const int id = decisions->prevId[i];
      if (id >= 4) {
        prev_state_s[i] = ctxs->m_skip_state_offset + (id - 4);
        state->m_numSigSbb[curr] = 0;
        memset(state->m_absLevelsAndCtxInit[curr], 0, 16 * sizeof(uint8_t));
      } else if (id >= 0) {
        prev_state_s[i] = ctxs->m_prev_state_offset + id;
        state->m_numSigSbb[curr] = state->m_numSigSbb[prev_state_s[i]] + !!decisions->absLevel[i];
        memcpy(state->m_absLevelsAndCtxInit[curr], state->m_absLevelsAndCtxInit[prev_state_s[i]], 16 * sizeof(uint8_t));
      } else {
        // prevId == -1: no previous state, prev_state_s[i] stays -1.
        state->m_numSigSbb[curr] = 1;
        memset(state->m_absLevelsAndCtxInit[curr], 0, 16 * sizeof(uint8_t));
        all_have_previous_state = false;
      }
    }
    prev_state = _mm_loadu_si128((const __m128i*)prev_state_s);
  }
  _mm_storeu_si128((__m128i*)prev_state_s, prev_state);

  // Store the absolute level of this scan position, clamped to 32, into the
  // level byte selected by the low four bits of scan_pos.
  {
    const uint32_t level_offset = scan_pos & 15;
    uint32_t max_abs_s[4];
    _mm_storeu_si128((__m128i*)max_abs_s, _mm_min_epi32(abs_level, _mm_set1_epi32(32)));
    for (int i = 0; i < 4; ++i) {
      uint8_t* levels = (uint8_t*)state->m_absLevelsAndCtxInit[state_offset + i];
      levels[level_offset] = max_abs_s[i];
    }
  }

  // Update common context
  common_context* cc = &ctxs->m_common_context;
  const uint32_t numSbb = width_in_sbb * height_in_sbb;
  const size_t setCpSize = cc->m_nbInfo[scan_pos - 1].maxDist * sizeof(uint8_t);
  int32_t sig_sbb_s[4];
  for (int curr_state = 0; curr_state < 4; ++curr_state) {
    uint8_t* sbbFlags = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset + curr_state].sbbFlags;
    uint8_t* levels = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset + curr_state].levels;
    const int p_state = prev_state_s[curr_state];
    if (p_state != -1 && state->m_refSbbCtxId[p_state] >= 0) {
      const int prev_sbb = cc->m_prev_sbb_ctx_offset + state->m_refSbbCtxId[p_state];
      memcpy(sbbFlags, cc->m_allSbbCtx[prev_sbb].sbbFlags, numSbb * sizeof(uint8_t));
      memcpy(levels + scan_pos, cc->m_allSbbCtx[prev_sbb].levels + scan_pos, setCpSize);
    } else {
      memset(sbbFlags, 0, numSbb * sizeof(uint8_t));
      memset(levels + scan_pos, 0, setCpSize);
    }
    sbbFlags[cg_pos] = !!state->m_numSigSbb[curr_state + state_offset];
    memcpy(levels + scan_pos, state->m_absLevelsAndCtxInit[curr_state + state_offset], 16 * sizeof(uint8_t));

    // Row of m_sbbFlagBits to use: 1 when the right or the below neighbour
    // sub-block is flagged in this state's own flag buffer, else 0. The flags
    // are read through the buffer pointer (not the address of the pointer
    // member) and each neighbour is gated on its own index.
    const bool sig_right = next_sbb_right && sbbFlags[next_sbb_right];
    const bool sig_below = next_sbb_below && sbbFlags[next_sbb_below];
    sig_sbb_s[curr_state] = sig_right || sig_below;
  }

  // Each row of m_sbbFlagBits is two 32-bit values, fetched as one 64-bit lane.
  const __m128i sig_sbb = _mm_loadu_si128((const __m128i*)sig_sbb_s);
  const __m256i sbb_frac_bits = _mm256_i32gather_epi64((const long long*)cc->m_sbbFlagBits, sig_sbb, 8);
  _mm256_storeu_si256((__m256i*)state->m_sbbFracBits[state_offset], sbb_frac_bits);

  memset(&state->m_numSigSbb[state_offset], 0, 4);
  memset(&state->m_goRicePar[state_offset], 0, 4);

  const uint8_t states[4] = {0, 1, 2, 3};
  memcpy(&state->m_refSbbCtxId[state_offset], states, 4);

  if (all_have_previous_state) {
    const __m128i rem_reg_bins = _mm_i32gather_epi32((const int*)state->m_remRegBins, prev_state, 4);
    _mm_storeu_si128((__m128i*)&state->m_remRegBins[state_offset], rem_reg_bins);
  } else {
    const int initial_bins = (state->effWidth * state->effHeight * 28) / 16;
    for (int i = 0; i < 4; ++i) {
      state->m_remRegBins[i + state_offset] =
        prev_state_s[i] != -1 ? state->m_remRegBins[prev_state_s[i]] : initial_bins;
    }
  }

  // Build the packed template value of every position of the sub-block that
  // just ended. Lane k of each vector belongs to state k; the level buffers of
  // consecutive states are cc->num_coeff bytes apart.
  const int scanBeg = scan_pos - 16;
  const NbInfoOut* nbOut = cc->m_nbInfo + scanBeg;
  const uint8_t* absLevels = cc->m_allSbbCtx[cc->m_curr_sbb_ctx_offset].levels + scanBeg;
  const __m128i levels_offsets = _mm_set_epi32(cc->num_coeff * 3, cc->num_coeff * 2, cc->num_coeff * 1, 0);
  const __m128i first_byte = _mm_set1_epi32(0xff);
  const __m128i ones = _mm_set1_epi32(1);
  const __m128i fours = _mm_set1_epi32(4);
  // Keeps the low 16 bits of every 32-bit lane, packed into the low 64 bits.
  const __m128i pack_low_words = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 0, 0, 0, 0, 0, 0, 0, 0);

  __m256i all[4];
  uint64_t temp[4];
  __m128i last = _mm_setzero_si128();
  for (int id = 0; id < 16; id++, nbOut++) {
    // Positions without neighbours get a template value of zero.
    __m128i template_ctx_init = _mm_setzero_si128();
    const int num_nb = nbOut->num;
    if (num_nb != 0) {
      assert(num_nb > 0 && num_nb <= 5);
      __m128i sum_abs = _mm_setzero_si128();
      __m128i sum_abs_1 = _mm_setzero_si128();
      __m128i sum_nz = _mm_setzero_si128();
      for (int k = 0; k < num_nb; ++k) {
        const __m128i offset = _mm_add_epi32(levels_offsets, _mm_set1_epi32(nbOut->outPos[k]));
        // The gather loads four bytes per lane; only the addressed byte is kept.
        __m128i t = _mm_i32gather_epi32((const int*)absLevels, offset, 1);
        t = _mm_and_si128(t, first_byte);
        sum_abs = _mm_add_epi32(sum_abs, t);
        // Count of non-zero neighbours: min(t, 1), so it fits the 3-bit field.
        sum_nz = _mm_add_epi32(sum_nz, _mm_min_epi32(t, ones));
        // min(t, 4 + (t & 1))
        sum_abs_1 = _mm_add_epi32(
          sum_abs_1,
          _mm_min_epi32(t, _mm_add_epi32(fours, _mm_and_si128(t, ones))));
      }
      // Layout: bits 0-2 non-zero count, bits 3-7 sum_abs_1, bits 8+ min(127, sum_abs).
      sum_abs_1 = _mm_slli_epi32(sum_abs_1, 3);
      sum_abs = _mm_slli_epi32(_mm_min_epi32(_mm_set1_epi32(127), sum_abs), 8);
      template_ctx_init = _mm_add_epi32(_mm_add_epi32(sum_nz, sum_abs), sum_abs_1);
    }

    const __m128i packed = _mm_shuffle_epi8(template_ctx_init, pack_low_words);
    temp[id % 4] = (uint64_t)_mm_extract_epi64(packed, 0);
    if (id % 4 == 3) {
      // all[j] covers positions 4j..4j+3, four 16-bit values (one per state) each.
      all[id / 4] = _mm256_loadu_si256((const __m256i*)temp);
    }
    // After the loop this holds the value of the final position (id == 15).
    last = template_ctx_init;
  }
  // TODO(review): write the values in `all` back to the states; the storage
  // layout they belong to is not visible from this function.
  (void)all;

  for (int i = 0; i < 4; ++i) {
    memset(state->m_absLevelsAndCtxInit[state_offset + i], 0, 16);
  }

  // Derive the context selection from the template value of the last position.
  const __m128i sum_num = _mm_and_si128(last, _mm_set1_epi32(7));
  const __m128i sum_abs1 = _mm_and_si128(_mm_srli_epi32(last, 3), _mm_set1_epi32(31));
  const __m128i sum_abs_min = _mm_min_epi32(
    _mm_set1_epi32(3),
    _mm_srli_epi32(_mm_add_epi32(sum_abs1, _mm_set1_epi32(1)), 1));

  // NOTE(review): the lane stride of 12 assumes every state owns 12 two-value
  // entries in m_sigFracBitsArray - confirm against the struct definition.
  __m128i offsets = _mm_set_epi32(12 * 3, 12 * 2, 12 * 1, 12 * 0);
  offsets = _mm_add_epi32(offsets, _mm_set1_epi32(sigCtxOffsetNext));
  offsets = _mm_add_epi32(offsets, sum_abs_min);
  const __m256i sig_frac_bits =
    _mm256_i32gather_epi64((const long long*)state->m_sigFracBitsArray[state_offset][0], offsets, 8);
  _mm256_storeu_si256((__m256i*)&state->m_sigFracBits[state_offset][0], sig_frac_bits);

  const __m128i sum_gt1 = _mm_sub_epi32(sum_abs1, sum_num);
  const __m128i min_gt1 = _mm_min_epi32(sum_gt1, _mm_set1_epi32(4));
  uint32_t sum_gt1_s[4];
  _mm_storeu_si128((__m128i*)sum_gt1_s, min_gt1);
  for (int i = 0; i < 4; ++i) {
    memcpy(state->m_coeffFracBits[state_offset + i],
           state->m_gtxFracBitsArray[gtxCtxOffsetNext + sum_gt1_s[i]],
           sizeof(state->m_coeffFracBits[0]));
  }
}
static INLINE void updateStateEOS(
context_store* ctxs,
@ -1215,7 +1535,7 @@ static INLINE void updateStateEOS(
memset(state->m_absLevelsAndCtxInit[curr_state_offset], 0, 16 * sizeof(uint8_t));
}
uint8_t* temp = (uint8_t*)(&state->m_absLevelsAndCtxInit[curr_state_offset][scan_pos & 15]);
*temp = (uint8_t)MIN(255, decisions->absLevel[decision_id]);
*temp = (uint8_t)MIN(32, decisions->absLevel[decision_id]);
update_common_context(ctxs, state->m_commonCtx, scan_pos, cg_pos, width_in_sbb, height_in_sbb, next_sbb_right,
next_sbb_below, prvState, ctxs->m_curr_state_offset + decision_id);
@ -1925,7 +2245,7 @@ static void xDecideAndUpdate(
if (scan_pos) {
if (!(scan_pos & 15)) {
SWAP(ctxs->m_common_context.m_currSbbCtx, ctxs->m_common_context.m_prevSbbCtx, SbbCtx*);
SWAP(ctxs->m_common_context.m_curr_sbb_ctx_offset, ctxs->m_common_context.m_prev_sbb_ctx_offset, int);
updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 0);
updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 1);
updateStateEOS(ctxs, scan_pos, scan_info->cg_pos, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], width_in_sbb, height_in_sbb, scan_info->next_sbb_right, scan_info->next_sbb_below, decisions, 2);
@ -1933,6 +2253,7 @@ static void xDecideAndUpdate(
memcpy(decisions->prevId + 4, decisions->prevId, 4 * sizeof(int32_t));
memcpy(decisions->absLevel + 4, decisions->absLevel, 4 * sizeof(int32_t));
memcpy(decisions->rdCost + 4, decisions->rdCost, 4 * sizeof(int64_t));
printf("\n");
} else if (!zeroOut) {
update_states_avx2(ctxs, next_nb_info_ssb.num, scan_pos, decisions, scan_info->sig_ctx_offset[is_chroma], scan_info->gtx_ctx_offset[is_chroma], next_nb_info_ssb, 4, false);
/* updateState(ctxs, next_nb_info_ssb.num, scan_pos, decisions, sigCtxOffsetNext, gtxCtxOffsetNext, next_nb_info_ssb, 4, false, 0);