diff --git a/src/strategies/avx2/sao-avx2.c b/src/strategies/avx2/sao-avx2.c index 8edb85e4..709c991f 100644 --- a/src/strategies/avx2/sao-avx2.c +++ b/src/strategies/avx2/sao-avx2.c @@ -148,12 +148,18 @@ int kvz_sao_edge_ddistortion_avx2(const kvz_pixel *orig_data, const kvz_pixel *r } -/** - * \param orig_data Original pixel data. 64x64 for luma, 32x32 for chroma. - * \param rec_data Reconstructed pixel data. 64x64 for luma, 32x32 for chroma. - * \param dir_offsets - * \param is_chroma 0 for luma, 1 for chroma. Indicates - */ +static INLINE void accum_count_eo_cat_avx2(__m256i* __restrict v_diff_accum, __m256i* __restrict v_count, __m256i* __restrict v_cat, __m256i* __restrict v_diff, int eo_cat){ + __m256i v_mask = _mm256_cmpeq_epi32(*v_cat, _mm256_set1_epi32(eo_cat)); + *v_diff_accum = _mm256_add_epi32(*v_diff_accum, _mm256_and_si256(*v_diff, v_mask)); + *v_count = _mm256_sub_epi32(*v_count, v_mask); +} + + +#define ACCUM_COUNT_EO_CAT_AVX2(EO_CAT, V_CAT) \ + \ + accum_count_eo_cat_avx2(&(v_diff_accum[ EO_CAT ]), &(v_count[ EO_CAT ]), &V_CAT , &v_diff, EO_CAT); + + void kvz_calc_sao_edge_dir_avx2(const kvz_pixel *orig_data, const kvz_pixel *rec_data, int eo_class, int block_width, int block_height, int cat_sum_cnt[2][NUM_SAO_EDGE_CATEGORIES]) @@ -161,22 +167,84 @@ void kvz_calc_sao_edge_dir_avx2(const kvz_pixel *orig_data, const kvz_pixel *rec int y, x; vector2d_t a_ofs = g_sao_edge_offsets[eo_class][0]; vector2d_t b_ofs = g_sao_edge_offsets[eo_class][1]; - // Arrays orig_data and rec_data are quarter size for chroma. // Don't sample the edge pixels because this function doesn't have access to // their neighbours. + + __m256i v_diff_accum[NUM_SAO_EDGE_CATEGORIES] = { { 0 } }; + __m256i v_count[NUM_SAO_EDGE_CATEGORIES] = { { 0 } }; + for (y = 1; y < block_height - 1; ++y) { - for (x = 1; x < block_width - 1; ++x) { + + //Calculation for 8 pixels per round + for (x = 1; x < block_width - 8; x += 8) { const kvz_pixel *c_data = &rec_data[y * block_width + x]; - kvz_pixel a = c_data[a_ofs.y * block_width + a_ofs.x]; - kvz_pixel c = c_data[0]; - kvz_pixel b = c_data[b_ofs.y * block_width + b_ofs.x]; - int eo_cat = sao_calc_eo_cat(a, b, c); + __m128i v_c_data = _mm_loadl_epi64((__m128i* __restrict)c_data); + __m128i v_a = _mm_loadl_epi64((__m128i* __restrict)(&c_data[a_ofs.y * block_width + a_ofs.x])); + __m128i v_c = v_c_data; + __m128i v_b = _mm_loadl_epi64((__m128i* __restrict)(&c_data[b_ofs.y * block_width + b_ofs.x])); - cat_sum_cnt[0][eo_cat] += orig_data[y * block_width + x] - c; - cat_sum_cnt[1][eo_cat] += 1; + __m256i v_cat = _mm256_cvtepu8_epi32(sao_calc_eo_cat_avx2(&v_a, &v_b, &v_c)); + + __m256i v_diff = _mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i* __restrict)&(orig_data[y * block_width + x]))); + v_diff = _mm256_sub_epi32(v_diff, _mm256_cvtepu8_epi32(v_c)); + + //Accumulate differences and occurrences for each category + ACCUM_COUNT_EO_CAT_AVX2(SAO_EO_CAT0, v_cat); + ACCUM_COUNT_EO_CAT_AVX2(SAO_EO_CAT1, v_cat); + ACCUM_COUNT_EO_CAT_AVX2(SAO_EO_CAT2, v_cat); + ACCUM_COUNT_EO_CAT_AVX2(SAO_EO_CAT3, v_cat); + ACCUM_COUNT_EO_CAT_AVX2(SAO_EO_CAT4, v_cat); } + + //Handle last 6 pixels separately to prevent reading over boundary + const kvz_pixel *c_data = &rec_data[y * block_width + x]; + __m128i v_c_data = load_6_pixels(c_data); + const kvz_pixel* a_ptr = &c_data[a_ofs.y * block_width + a_ofs.x]; + const kvz_pixel* b_ptr = &c_data[b_ofs.y * block_width + b_ofs.x]; + __m128i v_a = load_6_pixels(a_ptr); + __m128i v_c = v_c_data; + __m128i v_b = load_6_pixels(b_ptr); + + __m256i v_cat = _mm256_cvtepu8_epi32(sao_calc_eo_cat_avx2(&v_a, &v_b, &v_c)); + + //Set the last two elements to a non-existing category to cause + //the accumulate-count macro to discard those values. + __m256i v_mask = _mm256_setr_epi32(0, 0, 0, 0, 0, 0, -1, -1); + v_cat = _mm256_or_si256(v_cat, v_mask); + + const kvz_pixel* orig_ptr = &(orig_data[y * block_width + x]); + __m256i v_diff = _mm256_cvtepu8_epi32(load_6_pixels(orig_ptr)); + v_diff = _mm256_sub_epi32(v_diff, _mm256_cvtepu8_epi32(v_c)); + + //Accumulate differences and occurrences for each category + ACCUM_COUNT_EO_CAT_AVX2(SAO_EO_CAT0, v_cat); + ACCUM_COUNT_EO_CAT_AVX2(SAO_EO_CAT1, v_cat); + ACCUM_COUNT_EO_CAT_AVX2(SAO_EO_CAT2, v_cat); + ACCUM_COUNT_EO_CAT_AVX2(SAO_EO_CAT3, v_cat); + ACCUM_COUNT_EO_CAT_AVX2(SAO_EO_CAT4, v_cat); + } + + for (int eo_cat = 0; eo_cat < NUM_SAO_EDGE_CATEGORIES; ++eo_cat) { + int accum = 0; + int count = 0; + + //Full horizontal sum of accumulated values + v_diff_accum[eo_cat] = _mm256_add_epi32(v_diff_accum[eo_cat], _mm256_castsi128_si256(_mm256_extracti128_si256(v_diff_accum[eo_cat], 1))); + v_diff_accum[eo_cat] = _mm256_add_epi32(v_diff_accum[eo_cat], _mm256_shuffle_epi32(v_diff_accum[eo_cat], KVZ_PERMUTE(2, 3, 0, 1))); + v_diff_accum[eo_cat] = _mm256_add_epi32(v_diff_accum[eo_cat], _mm256_shuffle_epi32(v_diff_accum[eo_cat], KVZ_PERMUTE(1, 0, 1, 0))); + accum += _mm_cvtsi128_si32(_mm256_castsi256_si128(v_diff_accum[eo_cat])); + + //Full horizontal sum of accumulated values + v_count[eo_cat] = _mm256_add_epi32(v_count[eo_cat], _mm256_castsi128_si256(_mm256_extracti128_si256(v_count[eo_cat], 1))); + v_count[eo_cat] = _mm256_add_epi32(v_count[eo_cat], _mm256_shuffle_epi32(v_count[eo_cat], KVZ_PERMUTE(2, 3, 0, 1))); + v_count[eo_cat] = _mm256_add_epi32(v_count[eo_cat], _mm256_shuffle_epi32(v_count[eo_cat], KVZ_PERMUTE(1, 0, 1, 0))); + count += _mm_cvtsi128_si32(_mm256_castsi256_si128(v_count[eo_cat])); + + cat_sum_cnt[0][eo_cat] += accum; + cat_sum_cnt[1][eo_cat] += count; + } }