Implement variance calculation in integer math

This may be a bit faster than the floating-point version, but it is less accurate.
Pauli Oikkonen 2020-02-27 17:52:40 +02:00
parent 35c825c75f
commit fc1b91335b

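The new path measures the mean as a Q8.7 fixed-point number and accumulates squared deviations as Q16.14, converting back to floating point only once at the end. For orientation, here is a rough scalar sketch of the same scheme; the function name is made up, and the AVX2 code in the diff below additionally drops each pixel's lowest bit and rounds the mean slightly differently:

#include <assert.h>
#include <stdint.h>

// Scalar illustration of the Q8.7 / Q16.14 scheme; not part of the patch.
static double pixel_var_fixed_point_sketch(const uint8_t *buf, uint32_t len)
{
  assert(len != 0 && (len & (len - 1)) == 0); // power-of-two length assumed

  // Mean in Q8.7: (sum / len) * 2^7
  uint64_t sum = 0;
  for (uint32_t i = 0; i < len; i++) sum += buf[i];
  const int32_t mean_q87 = (int32_t)((sum << 7) / len);

  // Squared Q8.7 deviations are Q16.14; accumulate them in 64 bits
  uint64_t var_sum_q1614 = 0;
  for (uint32_t i = 0; i < len; i++) {
    const int32_t dev = (int32_t)(buf[i] << 7) - mean_q87;
    var_sum_q1614 += (uint64_t)(dev * dev);
  }

  // Undo the Q16.14 scaling and divide by len
  return (double)var_sum_q1614 / ((double)len * (1 << 14));
}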

@@ -1137,12 +1137,87 @@ static double pixel_var_avx2_largebuf(const kvz_pixel *buf, const uint32_t len)
  return var_sum / len_f;
}

// Assumes that u is a power of two
static INLINE uint32_t ilog2(uint32_t u)
{
  return _tzcnt_u32(u);
}

// A B C D | E F G H  (8x32b)
//          ==>
// A+B C+D | E+F G+H  (4x64b)
static __m256i hsum_epi32_to_epi64(const __m256i v)
{
  const __m256i zero    = _mm256_setzero_si256();
  __m256i       v_shufd = _mm256_shuffle_epi32(v, _MM_SHUFFLE(3, 3, 1, 1));
  __m256i       sums_32 = _mm256_add_epi32   (v, v_shufd);
  __m256i       sums_64 = _mm256_blend_epi32 (sums_32, zero, 0xaa);
  return sums_64;
}
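
// For illustration (not in the original patch): hsum_epi32_to_epi64 turns an
// input of {1, 2, 3, 4 | 5, 6, 7, 8} (8x32b) into {3, 7 | 11, 15} (4x64b),
// because the odd 32-bit lanes of the pairwise sums are blended to zero.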

static double pixel_var_avx2(const kvz_pixel *buf, const uint32_t len)
{
  assert(sizeof(*buf) == 1);
  assert((len & 31) == 0);

  // Uses Q8.7 numbers to measure mean and deviation, so variances are Q16.14
  const uint64_t sum_maxwid = ilog2(len) + (8 * sizeof(*buf));

  // Normalize the mean to [0, 32767] so the signed 16-bit subtraction below
  // can never overflow
  const __m128i normalize_sum = _mm_cvtsi32_si128(sum_maxwid - 15);
  const __m128i debias_sum    = _mm_cvtsi32_si128(1 << (sum_maxwid - 16));
  const float   varsum_to_f   = 1.0f / (float)(1 << (14 + ilog2(len)));

  const bool power_of_two = (len & (len - 1)) == 0;
  if (sum_maxwid > 32 || sum_maxwid < 15 || !power_of_two) {
    return pixel_var_avx2_largebuf(buf, len);
  }

  const __m256i zero      = _mm256_setzero_si256();
  const __m256i himask_15 = _mm256_set1_epi16(0x7f00);

  size_t i;
  __m256i sums = zero;

  // Sum the pixels; SAD against zero adds each group of eight bytes into a
  // 64-bit lane
  for (i = 0; i < len; i += 32) {
    __m256i curr     = _mm256_loadu_si256((const __m256i *)(buf + i));
    __m256i curr_sum = _mm256_sad_epu8   (curr, zero);
    sums             = _mm256_add_epi64  (sums, curr_sum);
  }

  // Reduce the lane sums to a single total and scale it down to a Q8.7 mean
  __m128i sum_lo = _mm256_castsi256_si128  (sums);
  __m128i sum_hi = _mm256_extracti128_si256(sums, 1);
  __m128i sum_3  = _mm_add_epi64           (sum_lo, sum_hi);
  __m128i sum_4  = _mm_shuffle_epi32       (sum_3, _MM_SHUFFLE(1, 0, 3, 2));
  __m128i sum_5  = _mm_add_epi64           (sum_3, sum_4);
  __m128i sum_5n = _mm_srl_epi32           (sum_5, normalize_sum);
          sum_5n = _mm_add_epi32           (sum_5n, debias_sum);

  // Broadcast the Q8.7 mean to every 16-bit lane
  __m256i sum_n  = _mm256_broadcastw_epi16 (sum_5n);

  __m256i accum = zero;

  // Accumulate squared deviations from the mean. Each pixel is positioned as
  // a Q8.7-style value in bits 8..14 (its lowest bit is dropped), so the
  // subtraction of the Q8.7 mean cannot overflow signed 16 bits.
  for (i = 0; i < len; i += 32) {
    __m256i curr   = _mm256_loadu_si256((const __m256i *)(buf + i));

    __m256i curr0  = _mm256_slli_epi16 (curr, 7);
    __m256i curr1  = _mm256_srli_epi16 (curr, 1);
            curr0  = _mm256_and_si256  (curr0, himask_15);
            curr1  = _mm256_and_si256  (curr1, himask_15);

    __m256i dev0   = _mm256_sub_epi16  (curr0, sum_n);
    __m256i dev1   = _mm256_sub_epi16  (curr1, sum_n);

    // Square the deviations and add adjacent pairs: Q8.7 * Q8.7 = Q16.14
    __m256i vars0  = _mm256_madd_epi16 (dev0, dev0);
    __m256i vars1  = _mm256_madd_epi16 (dev1, dev1);
    __m256i varsum = _mm256_add_epi32  (vars0, vars1);

    // Widen the 32-bit partial sums to 64 bits before accumulating
            varsum = hsum_epi32_to_epi64(varsum);
            accum  = _mm256_add_epi64   (accum, varsum);
  }

  // Horizontal sum of the four 64-bit accumulators
  __m256i  accum2 = _mm256_permute4x64_epi64(accum,  _MM_SHUFFLE(1, 0, 3, 2));
  __m256i  accum3 = _mm256_add_epi64        (accum,  accum2);
  __m256i  accum4 = _mm256_permute4x64_epi64(accum3, _MM_SHUFFLE(2, 3, 1, 0));
  __m256i  v_tot  = _mm256_add_epi64        (accum3, accum4);
  __m128i  vt128  = _mm256_castsi256_si128  (v_tot);
  uint64_t vars   = _mm_cvtsi128_si64       (vt128);

  // Undo the Q16.14 scaling and divide by len with a single multiply
  return (float)vars * varsum_to_f;
}
#endif //COMPILE_INTEL_AVX2
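
Since the commit message flags the integer path as less accurate, one way to quantify the error is to compare it against pixel_var_avx2_largebuf on random blocks. A rough sketch of such a check (not part of the patch; it assumes 8-bit kvz_pixel, both variance functions are static so it would have to live in the same file, and the block size and trial count are arbitrary):

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

// Illustrative accuracy check, not part of the patch.
static void compare_variance_paths(void)
{
  enum { LEN = 64 * 64, TRIALS = 1000 };
  static kvz_pixel buf[LEN];
  double worst = 0.0;

  for (int t = 0; t < TRIALS; t++) {
    for (int i = 0; i < LEN; i++) {
      buf[i] = (kvz_pixel)(rand() & 0xff);
    }
    const double ref = pixel_var_avx2_largebuf(buf, LEN);
    const double fix = pixel_var_avx2         (buf, LEN);
    const double err = fabs(fix - ref);
    if (err > worst) worst = err;
  }
  printf("worst absolute variance error over %d trials: %f\n", TRIALS, worst);
}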