Implement streamlined generic 16x16 matrix multiply

It can't be this fast for real, can it?
This commit is contained in:
Pauli Oikkonen 2019-06-06 13:54:25 +03:00
parent beb85ce9d6
commit e463d27f22

View file

@@ -528,61 +528,48 @@ static void matmul_16x16_a_bt(const int16_t *a, const int16_t *b_t, int16_t *out
// destination for the result and the shift value for clipping.
// Multiplies the 16x16 int16 matrix `left` by the 16x16 int16 matrix `right`
// using AVX2, rounds each 32-bit sum by `shift` bits, saturates to int16 and
// stores the 16 result rows (one 256-bit vector each) to `dst`.
//
// NOTE(review): as shown here, this span appears to interleave two
// implementations (the lines of a diff hunk without their +/- markers): one
// built around row[]/accu[][]/even/odd and one built around accum1..4 and
// left_slice_*/right_slices_*. There are more opening than closing braces and
// `add` is declared twice, so the text is not compilable as-is; the comments
// below describe each group of lines on its own.
// NOTE(review): `1 << (shift - 1)` presumes shift >= 1 - confirm with callers.
static void mul_clip_matrix_16x16_avx2(const int16_t *left, const int16_t *right, int16_t *dst, const int32_t shift)
{
int i, j;
// row[]: interleaved rows of `right`; accu[i][0..1]: 32-bit sums for output row i.
__m256i row[4], accu[16][2], even, odd;
// Rounding offset: half of (1 << shift), added before the right shift.
const int32_t add = 1 << (shift - 1);
// Rounding offset broadcast to every 32-bit lane; handed to truncate() below.
const __m256i debias = _mm256_set1_epi32(add);
// Number of 32-bit words (int16 pairs) in one 16-element row of `left`.
const int32_t stride = 8;
// Two output rows (dry, dry + 1) are produced per iteration.
for (int32_t dry = 0; dry < 16; dry += 2) {
// accum1/accum3 collect row `dry`, accum2/accum4 collect row `dry + 1`.
__m256i accum1 = _mm256_setzero_si256();
__m256i accum2 = _mm256_setzero_si256();
__m256i accum3 = _mm256_setzero_si256();
__m256i accum4 = _mm256_setzero_si256();
const int32_t add = 1 << (shift - 1);
// Walk the inner dimension two elements at a time.
for (int32_t lx = 0; lx < 16; lx += 2) {
// Read left[row][lx] and left[row][lx + 1] as a single 32-bit word.
// NOTE(review): reading int16_t data through an int32_t pointer relies on
// the compiler tolerating this aliasing and on suitable alignment of `left`.
const int32_t *curr_left_up = (const int32_t *)(left + (dry + 0) * 16 + lx);
const int32_t *curr_left_dn = (const int32_t *)(left + (dry + 1) * 16 + lx);
// Load rows 0 and 1 of `right` and interleave them element-wise so that each
// 32-bit lane holds the pair (right[0][c], right[1][c]).
row[0] = _mm256_loadu_si256((__m256i*) right);
row[1] = _mm256_loadu_si256((__m256i*) right + 1);
row[2] = _mm256_unpacklo_epi16(row[0], row[1]);
row[3] = _mm256_unpackhi_epi16(row[0], row[1]);
// unpack works per 128-bit lane; the permutes regroup the halves so that
// row[0] covers columns 0-7 and row[1] covers columns 8-15.
row[0] = _mm256_permute2x128_si256(row[2], row[3], 0 + 32);
row[1] = _mm256_permute2x128_si256(row[2], row[3], 1 + 48);
// Broadcast the left-matrix element pair to all eight 32-bit lanes.
__m256i left_slice_lo = _mm256_set1_epi32(*curr_left_up);
__m256i left_slice_hi = _mm256_set1_epi32(*curr_left_dn);
for (i = 0; i < 16; i += 2) {
// Rows lx and lx + 1 of `right` (each row is one 256-bit vector).
__m256i right_up = _mm256_loadu_si256((const __m256i *)right + lx + 0);
__m256i right_dn = _mm256_loadu_si256((const __m256i *)right + lx + 1);
// First int16 pair of left row i, broadcast; madd yields
// left[i][0] * right[0][c] + left[i][1] * right[1][c] per column c.
even = _mm256_set1_epi32(((int32_t*)left)[stride * i]);
accu[i][0] = _mm256_madd_epi16(even, row[0]);
accu[i][1] = _mm256_madd_epi16(even, row[1]);
// Pair up right[lx][c] with right[lx + 1][c] in each 32-bit lane. No cross-lane
// permute is done here; the per-lane pack at the end restores column order.
__m256i right_slices_lo = _mm256_unpacklo_epi16(right_up, right_dn);
__m256i right_slices_hi = _mm256_unpackhi_epi16(right_up, right_dn);
// Same as above for left row i + 1.
odd = _mm256_set1_epi32(((int32_t*)left)[stride * (i + 1)]);
accu[i + 1][0] = _mm256_madd_epi16(odd, row[0]);
accu[i + 1][1] = _mm256_madd_epi16(odd, row[1]);
}
// Remaining row pairs (j, j + 1) of `right`, accumulated onto accu[][].
for (j = 2; j < 16; j += 2) {
row[0] = _mm256_loadu_si256((__m256i*)right + j);
row[1] = _mm256_loadu_si256((__m256i*)right + j + 1);
row[2] = _mm256_unpacklo_epi16(row[0], row[1]);
row[3] = _mm256_unpackhi_epi16(row[0], row[1]);
row[0] = _mm256_permute2x128_si256(row[2], row[3], 0 + 32);
row[1] = _mm256_permute2x128_si256(row[2], row[3], 1 + 48);
for (i = 0; i < 16; i += 2) {
// j / 2 selects the 32-bit word holding left[i][j] and left[i][j + 1].
even = _mm256_set1_epi32(((int32_t*)left)[stride * i + j / 2]);
accu[i][0] = _mm256_add_epi32(accu[i][0], _mm256_madd_epi16(even, row[0]));
accu[i][1] = _mm256_add_epi32(accu[i][1], _mm256_madd_epi16(even, row[1]));
odd = _mm256_set1_epi32(((int32_t*)left)[stride * (i + 1) + j / 2]);
accu[i + 1][0] = _mm256_add_epi32(accu[i + 1][0], _mm256_madd_epi16(odd, row[0]));
accu[i + 1][1] = _mm256_add_epi32(accu[i + 1][1], _mm256_madd_epi16(odd, row[1]));
// Multiply-add the broadcast left pairs against the interleaved right pairs:
// prod1/prod3 belong to row `dry`, prod2/prod4 to row `dry + 1`.
__m256i prod1 = _mm256_madd_epi16(left_slice_lo, right_slices_lo);
__m256i prod2 = _mm256_madd_epi16(left_slice_hi, right_slices_lo);
__m256i prod3 = _mm256_madd_epi16(left_slice_lo, right_slices_hi);
__m256i prod4 = _mm256_madd_epi16(left_slice_hi, right_slices_hi);
accum1 = _mm256_add_epi32(accum1, prod1);
accum2 = _mm256_add_epi32(accum2, prod2);
accum3 = _mm256_add_epi32(accum3, prod3);
accum4 = _mm256_add_epi32(accum4, prod4);
}
}
// truncate() is defined elsewhere; presumably it adds `debias` and shifts
// right arithmetically by `shift` - confirm against its definition.
__m256i accum1_tr = truncate(accum1, debias, shift);
__m256i accum2_tr = truncate(accum2, debias, shift);
__m256i accum3_tr = truncate(accum3, debias, shift);
__m256i accum4_tr = truncate(accum4, debias, shift);
// Round, shift, saturate to int16 and store each row held in accu[][].
for (i = 0; i < 16; ++i) {
__m256i result, first_half, second_half;
first_half = _mm256_srai_epi32(_mm256_add_epi32(accu[i][0], _mm256_set1_epi32(add)), shift);
second_half = _mm256_srai_epi32(_mm256_add_epi32(accu[i][1], _mm256_set1_epi32(add)), shift);
// packs interleaves per 128-bit lane; the permute (64-bit quarters 0,2,1,3)
// puts the 16 values back into column order.
result = _mm256_permute4x64_epi64(_mm256_packs_epi32(first_half, second_half), 0 + 8 + 16 + 192);
_mm256_storeu_si256((__m256i*)dst + i, result);
// Saturating pack to int16; being per-lane, it undoes the per-lane split made
// by unpacklo/unpackhi, so no extra permute is needed before the store.
__m256i out_up = _mm256_packs_epi32(accum1_tr, accum3_tr);
__m256i out_dn = _mm256_packs_epi32(accum2_tr, accum4_tr);
// Output rows dry and dry + 1 (one 256-bit vector per row).
_mm256_storeu_si256((__m256i *)dst + dry + 0, out_up);
_mm256_storeu_si256((__m256i *)dst + dry + 1, out_dn);
}
}