diff --git a/src/strategies/avx2/picture-avx2.c b/src/strategies/avx2/picture-avx2.c index 6bb729a9..8680590a 100644 --- a/src/strategies/avx2/picture-avx2.c +++ b/src/strategies/avx2/picture-avx2.c @@ -233,7 +233,9 @@ static unsigned satd_4x4_8bit_avx2(const uint8_t *org, const uint8_t *cur) row3 = _mm_add_epi16(row3, _mm_shuffle_epi32(row3, _MM_SHUFFLE(0, 1, 0, 1) )); row3 = _mm_add_epi16(row3, _mm_shufflelo_epi16(row3, _MM_SHUFFLE(0, 1, 0, 1) )); + const int dc1 = abs(_mm_extract_epi16(row2, 0)); unsigned sum = _mm_extract_epi16(row3, 0); + sum -= dc1 - (dc1 >> 2); unsigned satd = (sum + 1) >> 1; return satd; @@ -280,10 +282,16 @@ static void satd_8bit_4x4_dual_avx2( row3 = _mm256_add_epi16(row3, _mm256_shuffle_epi32(row3, _MM_SHUFFLE(0, 1, 0, 1) )); row3 = _mm256_add_epi16(row3, _mm256_shufflelo_epi16(row3, _MM_SHUFFLE(0, 1, 0, 1) )); + const int16_t temp2 = _mm256_extract_epi16(row2, 0); + const int dc1 = abs(temp2); unsigned sum1 = _mm_extract_epi16(_mm256_castsi256_si128(row3), 0); + sum1 -= dc1 - (dc1 >> 2); sum1 = (sum1 + 1) >> 1; + const int16_t temp3 = _mm256_extract_epi16(row2, 8); + const int dc2 = abs(temp3); unsigned sum2 = _mm_extract_epi16(_mm256_extracti128_si256(row3, 1), 0); + sum2 -= dc2 - (dc2 >> 2); sum2 = (sum2 + 1) >> 1; satds_out[0] = sum1; @@ -522,6 +530,13 @@ static void uvg_satd_8bit_8x8_general_dual_avx2(const uint8_t * buf1, unsigned s sum_block_dual_avx2(temp, sum0, sum1); + const int16_t temp2 = _mm256_extract_epi16(temp[0], 0); + const int dc1 = abs(temp2); + const int16_t temp3 = _mm256_extract_epi16(temp[0], 8); + const int dc2 = abs(temp3); + *sum0 -= dc1 - (dc1 >> 2); + *sum1 -= dc2 - (dc2 >> 2); + *sum0 = (*sum0 + 2) >> 2; *sum1 = (*sum1 + 2) >> 2; } @@ -558,6 +573,9 @@ static unsigned satd_8x8_subblock_8bit_avx2(const uint8_t * buf1, unsigned strid unsigned sad = sum_block_avx2(temp); + const int dc1 = abs(_mm_extract_epi16(temp[0], 0)); + sad -= dc1 - (dc1 >> 2); + unsigned result = (sad + 2) >> 2; return result; } diff --git a/tests/tests_main.c b/tests/tests_main.c index 8b45be99..a0a0dd80 100644 --- a/tests/tests_main.c +++ b/tests/tests_main.c @@ -38,7 +38,7 @@ GREATEST_MAIN_DEFS(); #if UVG_BIT_DEPTH == 8 extern SUITE(sad_tests); extern SUITE(intra_sad_tests); -// extern SUITE(satd_tests); +extern SUITE(satd_tests); extern SUITE(speed_tests); extern SUITE(dct_tests); extern SUITE(mts_tests); @@ -56,7 +56,7 @@ int main(int argc, char **argv) #if UVG_BIT_DEPTH == 8 RUN_SUITE(sad_tests); RUN_SUITE(intra_sad_tests); - // RUN_SUITE(satd_tests); + RUN_SUITE(satd_tests); RUN_SUITE(dct_tests); RUN_SUITE(mts_tests);