Retain data as much in YMM registers as possible

This seems to make it a whole lot quicker
This commit is contained in:
Pauli Oikkonen 2019-10-15 19:43:11 +03:00
parent 9589baccac
commit 4a921cbdb5

View file

@ -861,17 +861,7 @@ static void matmul_32x32_a_bt(const __m256i *__restrict a,
__m256i a_lo = a[jd + 0];
__m256i a_hi = a[jd + 1];
// Each lane here is the 4 i32's that add up to one number in the result
// vector:
// 00 | 08
// 01 | 09
// ...
// 07 | 0f
// 10 | 18
// 11 | 19
// ...
// 17 | 1f
__m256i rp1[16];
__m256i res_32[4];
for (i = 0; i < 16; i += 4) {
// Loop in order:
// 0, 1, 2, 3, 4, 5, 6, 7,
@ -935,48 +925,46 @@ static void matmul_32x32_a_bt(const __m256i *__restrict a,
__m256i s3b_lo = _mm256_permute2x128_si256(sum3, sumb, 0x20);
__m256i s3b_hi = _mm256_permute2x128_si256(sum3, sumb, 0x31);
rp1[i + 0] = _mm256_add_epi32(s08_lo, s08_hi);
rp1[i + 1] = _mm256_add_epi32(s19_lo, s19_hi);
rp1[i + 2] = _mm256_add_epi32(s2a_lo, s2a_hi);
rp1[i + 3] = _mm256_add_epi32(s3b_lo, s3b_hi);
// Each lane here is the 4 i32's that add up to one number in the result
// vector:
// 00 | 08
// 01 | 09
// ...
// 07 | 0f
// 10 | 18
// 11 | 19
// ...
// 17 | 1f
__m256i rp1_0 = _mm256_add_epi32(s08_lo, s08_hi);
__m256i rp1_1 = _mm256_add_epi32(s19_lo, s19_hi);
__m256i rp1_2 = _mm256_add_epi32(s2a_lo, s2a_hi);
__m256i rp1_3 = _mm256_add_epi32(s3b_lo, s3b_hi);
// rp2_0 and rp2_1 values on different iterations
// 00 00 01 01 | 08 08 09 09
// 02 02 03 03 | 0a 0a 0b 0b
//
// 04 04 05 05 | 0c 0c 0d 0d
// 06 06 07 07 | 0e 0e 0f 0f
//
// 10 10 11 11 | 18 18 19 19
// 12 12 13 13 | 1a 1a 1b 1b
//
// 16 16 17 17 | 1e 1e 1f 1f
// 16 16 17 17 | 1e 1e 1f 1f
__m256i rp2_0 = _mm256_hadd_epi32(rp1_0, rp1_1);
__m256i rp2_1 = _mm256_hadd_epi32(rp1_2, rp1_3);
// rp3 values on different iters
// 00 01 02 03 | 08 09 0a 0b
// 04 05 06 07 | 0c 0d 0e 0f
// 10 11 12 13 | 18 19 1a 1b
// 14 15 16 17 | 1c 1d 1e 1f
__m256i rp3 = _mm256_hadd_epi32(rp2_0, rp2_1);
res_32[i >> 2] = truncate(rp3, debias, shift);
}
// 00 00 01 01 | 08 08 09 09
// 02 02 03 03 | 0a 0a 0b 0b
// 04 04 05 05 | 0c 0c 0d 0d
// 06 06 07 07 | 0e 0e 0f 0f
// 10 10 11 11 | 18 18 19 19
// ...
// 16 16 17 17 | 1e 1e 1f 1f
__m256i rp2[8] = {
_mm256_hadd_epi32(rp1[ 0], rp1[ 1]),
_mm256_hadd_epi32(rp1[ 2], rp1[ 3]),
_mm256_hadd_epi32(rp1[ 4], rp1[ 5]),
_mm256_hadd_epi32(rp1[ 6], rp1[ 7]),
_mm256_hadd_epi32(rp1[ 8], rp1[ 9]),
_mm256_hadd_epi32(rp1[10], rp1[11]),
_mm256_hadd_epi32(rp1[12], rp1[13]),
_mm256_hadd_epi32(rp1[14], rp1[15]),
};
// 00 01 02 03 | 08 09 0a 0b
// 04 05 06 07 | 0c 0d 0e 0f
// 10 11 12 13 | 18 19 1a 1b
// 14 15 16 17 | 1c 1d 1e 1f
__m256i rp3[4] = {
_mm256_hadd_epi32(rp2[0], rp2[1]),
_mm256_hadd_epi32(rp2[2], rp2[3]),
_mm256_hadd_epi32(rp2[4], rp2[5]),
_mm256_hadd_epi32(rp2[6], rp2[7]),
};
rp3[0] = truncate(rp3[0], debias, shift);
rp3[1] = truncate(rp3[1], debias, shift);
rp3[2] = truncate(rp3[2], debias, shift);
rp3[3] = truncate(rp3[3], debias, shift);
dst[jd + 0] = _mm256_packs_epi32(rp3[0], rp3[1]);
dst[jd + 1] = _mm256_packs_epi32(rp3[2], rp3[3]);
dst[jd + 0] = _mm256_packs_epi32(res_32[0], res_32[1]);
dst[jd + 1] = _mm256_packs_epi32(res_32[2], res_32[3]);
}
}