diff --git a/src/strategies/avx2/intra-avx2.c b/src/strategies/avx2/intra-avx2.c
index 46de8ced..afe30ac5 100644
--- a/src/strategies/avx2/intra-avx2.c
+++ b/src/strategies/avx2/intra-avx2.c
@@ -973,16 +973,35 @@ static void kvz_pdpc_planar_dc_avx2(
 {
   assert(mode == 0 || mode == 1); // planar or DC
 
+  __m256i shuf_mask_byte = _mm256_setr_epi8(
+    0, -1, 0, -1, 0, -1, 0, -1,
+    1, -1, 1, -1, 1, -1, 1, -1,
+    2, -1, 2, -1, 2, -1, 2, -1,
+    3, -1, 3, -1, 3, -1, 3, -1
+  );
+
+  __m256i shuf_mask_word = _mm256_setr_epi8(
+    0, 1, 0, 1, 0, 1, 0, 1,
+    2, 3, 2, 3, 2, 3, 2, 3,
+    4, 5, 4, 5, 4, 5, 4, 5,
+    6, 7, 6, 7, 6, 7, 6, 7
+  );
+
   // TODO: replace latter log2_width with log2_height
   const int scale = ((log2_width - 2 + log2_width - 2 + 2) >> 2);
 
-  int16_t w[LCU_WIDTH];
-  int16_t left[4][4];
-  int16_t top [4][4];
-
   // Same weights regardless of axis, compute once
-  for (int i = 0; i < width; ++i) {
-    w[i] = 32 >> MIN(31, ((i << 1) >> scale));
+  int16_t w[LCU_WIDTH];
+  for (int i = 0; i < width; i += 4) {
+    __m128i base = _mm_set1_epi32(i);
+    __m128i offs = _mm_setr_epi32(0, 1, 2, 3);
+    __m128i idxs = _mm_add_epi32(base, offs);
+    __m128i unclipped = _mm_slli_epi32(idxs, 1);
+    unclipped = _mm_srli_epi32(unclipped, scale);
+    __m128i clipped = _mm_min_epi32( _mm_set1_epi32(31), unclipped);
+    __m128i weights = _mm_srlv_epi32(_mm_set1_epi32(32), clipped);
+    weights = _mm_packus_epi32(weights, weights);
+    _mm_storel_epi64((__m128i*)&w[i], weights);
   }
 
   // Process in 4x4 blocks
@@ -990,21 +1009,43 @@ static void kvz_pdpc_planar_dc_avx2(
   for (int y = 0; y < width; y += 4) {
     for (int x = 0; x < width; x += 4) {
 
-      for (int yy = 0; yy < 4; ++yy) {
-        for (int xx = 0; xx < 4; ++xx) {
-          left[yy][xx] = used_ref->left[(y + yy) + 1];
-          top [yy][xx] = used_ref->top [(x + xx) + 1];
-        }
-      }
+      uint32_t dw_left;
+      uint32_t dw_top;
+      memcpy(&dw_left, &used_ref->left[y + 1], sizeof(dw_left));
+      memcpy(&dw_top , &used_ref->top [x + 1], sizeof(dw_top));
+      __m256i vleft = _mm256_set1_epi32(dw_left);
+      __m256i vtop = _mm256_set1_epi32(dw_top);
+      vleft = _mm256_shuffle_epi8(vleft, shuf_mask_byte);
+      vtop = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(vtop));
 
-      for (int yy = 0; yy < 4; ++yy) {
-        for (int xx = 0; xx < 4; ++xx) {
-          dst[(x + xx) + (y + yy) * width] += ((
-            w[(x + xx)] * (left[yy][xx] - dst[(x + xx) + (y + yy) * width]) +
-            w[(y + yy)] * (top [yy][xx] - dst[(x + xx) + (y + yy) * width]) +
-            32) >> 6);
-        }
-      }
+      __m128i vseq = _mm_setr_epi32(0, 1, 2, 3);
+      __m128i vidx = _mm_slli_epi32(vseq, log2_width);
+      __m128i vdst = _mm_i32gather_epi32((uint32_t*)(dst + y * width + x), vidx, 1);
+      __m256i vdst16 = _mm256_cvtepu8_epi16(vdst);
+      uint64_t quad_wL;
+      uint64_t quad_wT;
+      memcpy(&quad_wL, &w[x], sizeof(quad_wL));
+      memcpy(&quad_wT, &w[y], sizeof(quad_wT));
+      __m256i vwL = _mm256_set1_epi64x(quad_wL);
+      __m256i vwT = _mm256_set1_epi64x(quad_wT);
+      vwT = _mm256_shuffle_epi8(vwT, shuf_mask_word);
+      __m256i diff_left = _mm256_sub_epi16(vleft, vdst16);
+      __m256i diff_top = _mm256_sub_epi16(vtop , vdst16);
+      __m256i prod_left = _mm256_mullo_epi16(vwL, diff_left);
+      __m256i prod_top = _mm256_mullo_epi16(vwT, diff_top);
+      __m256i accu = _mm256_add_epi16(prod_left, prod_top);
+      accu = _mm256_add_epi16(accu, _mm256_set1_epi16(32));
+      accu = _mm256_srai_epi16(accu, 6);
+      accu = _mm256_add_epi16(vdst16, accu);
+
+      __m128i lo = _mm256_castsi256_si128(accu);
+      __m128i hi = _mm256_extracti128_si256(accu, 1);
+      vdst = _mm_packus_epi16(lo, hi);
+
+      *(uint32_t*)(dst + (y + 0) * width + x) = _mm_extract_epi32(vdst, 0);
+      *(uint32_t*)(dst + (y + 1) * width + x) = _mm_extract_epi32(vdst, 1);
+      *(uint32_t*)(dst + (y + 2) * width + x) = _mm_extract_epi32(vdst, 2);
+      *(uint32_t*)(dst + (y + 3) * width + x) = _mm_extract_epi32(vdst, 3);
     }
   }
 }
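
Note (not part of the patch): the first hunk replaces the scalar weight loop, 32 >> MIN(31, ((i << 1) >> scale)), with a four-at-a-time SSE/AVX2 computation. The standalone sketch below cross-checks the two under the same scale formula; the local MIN macro and the width range are stand-ins for illustration, not kvazaar definitions. Build with e.g. gcc -mavx2.

/* Sketch: vectorized PDPC weights vs. the scalar formula they replace. */
#include <stdint.h>
#include <stdio.h>
#include <immintrin.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b)) /* local stand-in for kvazaar's macro */

int main(void)
{
  for (int log2_width = 2; log2_width <= 6; ++log2_width) {
    const int width = 1 << log2_width;
    const int scale = ((log2_width - 2 + log2_width - 2 + 2) >> 2);

    int16_t w[64];
    for (int i = 0; i < width; i += 4) {
      // Same sequence of intrinsics as the patched weight loop.
      __m128i base = _mm_set1_epi32(i);
      __m128i offs = _mm_setr_epi32(0, 1, 2, 3);
      __m128i idxs = _mm_add_epi32(base, offs);
      __m128i unclipped = _mm_slli_epi32(idxs, 1);
      unclipped = _mm_srli_epi32(unclipped, scale);
      __m128i clipped = _mm_min_epi32(_mm_set1_epi32(31), unclipped);
      __m128i weights = _mm_srlv_epi32(_mm_set1_epi32(32), clipped);
      weights = _mm_packus_epi32(weights, weights);
      _mm_storel_epi64((__m128i*)&w[i], weights);
    }

    for (int i = 0; i < width; ++i) {
      const int16_t ref = 32 >> MIN(31, ((i << 1) >> scale));
      if (w[i] != ref) {
        printf("width %2d, i %2d: vector %d != scalar %d\n", width, i, w[i], ref);
        return 1;
      }
    }
  }
  printf("vectorized weights match the scalar formula\n");
  return 0;
}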
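
Note (not part of the patch): in the second hunk, _mm256_shuffle_epi8 indexes bytes within each 128-bit lane and zeroes any destination byte whose mask byte is negative. Because _mm256_set1_epi32 broadcasts the same four left-reference pixels into both lanes, the indices 0..3 in shuf_mask_byte work in both halves, and each pixel ends up repeated across one row of four zero-extended 16-bit lanes. The sketch below demonstrates this with made-up pixel values.

/* Sketch: what shuf_mask_byte does to four left-reference pixels. */
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <immintrin.h>

int main(void)
{
  const uint8_t left[4] = { 10, 20, 30, 40 }; // illustrative pixel values

  const __m256i shuf_mask_byte = _mm256_setr_epi8(
    0, -1, 0, -1, 0, -1, 0, -1,
    1, -1, 1, -1, 1, -1, 1, -1,
    2, -1, 2, -1, 2, -1, 2, -1,
    3, -1, 3, -1, 3, -1, 3, -1
  );

  uint32_t dw_left;
  memcpy(&dw_left, left, sizeof(dw_left));
  __m256i vleft = _mm256_set1_epi32(dw_left);      // both lanes hold the 4 pixels
  vleft = _mm256_shuffle_epi8(vleft, shuf_mask_byte);

  int16_t rows[16];
  _mm256_storeu_si256((__m256i*)rows, vleft);
  for (int row = 0; row < 4; ++row) {
    // Expected: row 0 -> 10 10 10 10, row 1 -> 20 20 20 20, and so on.
    printf("row %d: %d %d %d %d\n", row,
           rows[4 * row + 0], rows[4 * row + 1],
           rows[4 * row + 2], rows[4 * row + 3]);
  }
  return 0;
}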