From 7e089f518dad56c9106967a347ceff1f549b31d2 Mon Sep 17 00:00:00 2001 From: Marko Viitanen Date: Wed, 21 Jul 2021 11:53:15 +0300 Subject: [PATCH] [mts] add optimized versions of DCT8 and DST7, inverse not yet working properly * Includes new unit tests for the mts --- build/kvazaar_tests/kvazaar_tests.vcxproj | 1 + .../kvazaar_tests.vcxproj.filters | 5 +- src/kvazaar.h | 10 + src/strategies/avx2/dct-avx2.c | 699 ++++++++++++++++++ src/strategies/generic/dct-generic.c | 18 +- src/strategies/strategies-dct.h | 28 +- tests/Makefile.am | 1 + tests/mts_tests.c | 221 ++++++ tests/tests_main.c | 2 + 9 files changed, 959 insertions(+), 26 deletions(-) create mode 100644 tests/mts_tests.c diff --git a/build/kvazaar_tests/kvazaar_tests.vcxproj b/build/kvazaar_tests/kvazaar_tests.vcxproj index 7c0e2007..120eabd9 100644 --- a/build/kvazaar_tests/kvazaar_tests.vcxproj +++ b/build/kvazaar_tests/kvazaar_tests.vcxproj @@ -100,6 +100,7 @@ + diff --git a/build/kvazaar_tests/kvazaar_tests.vcxproj.filters b/build/kvazaar_tests/kvazaar_tests.vcxproj.filters index a7b14138..13798e3e 100644 --- a/build/kvazaar_tests/kvazaar_tests.vcxproj.filters +++ b/build/kvazaar_tests/kvazaar_tests.vcxproj.filters @@ -42,6 +42,9 @@ Source Files + + Source Files + @@ -54,4 +57,4 @@ Header Files - + \ No newline at end of file diff --git a/src/kvazaar.h b/src/kvazaar.h index df3d50bb..317b2bed 100644 --- a/src/kvazaar.h +++ b/src/kvazaar.h @@ -230,6 +230,16 @@ enum kvz_mts { KVZ_MTS_IMPLICIT = 4, }; + +//MTS transform tags +typedef enum tr_type_t { + DCT2 = 0, + DCT8 = 1, + DST7 = 2, + NUM_TRANS_TYPE = 3, + DCT2_MTS = 4 +} tr_type_t; + enum kvz_scalinglist { KVZ_SCALING_LIST_OFF = 0, KVZ_SCALING_LIST_CUSTOM = 1, diff --git a/src/strategies/avx2/dct-avx2.c b/src/strategies/avx2/dct-avx2.c index dbd8e3b8..de90b3c7 100644 --- a/src/strategies/avx2/dct-avx2.c +++ b/src/strategies/avx2/dct-avx2.c @@ -44,6 +44,8 @@ extern const int16_t kvz_g_dct_8_t[8][8]; extern const int16_t kvz_g_dct_16_t[16][16]; extern 
const int16_t kvz_g_dct_32_t[32][32]; + + /* * \file * \brief AVX2 transformations. @@ -929,6 +931,699 @@ ITRANSFORM(dct, 32); #endif // KVZ_BIT_DEPTH == 8 #endif //COMPILE_INTEL_AVX2 + + +/*****************************************************/ +/********************** M T S ************************/ +/*****************************************************/ + +// DST-7 +#define DEFINE_DST7_P4_MATRIX(a,b,c,d) { \ + { a, b, c, d},\ + { c, c, 0, -c},\ + { d, -a, -c, b},\ + { b, -d, c, -a},\ +} + +#define DEFINE_DST7_P4_MATRIX_T(a,b,c,d) { \ + { a, c, d, b},\ + { b, c, -a, -d},\ + { c, 0, -c, c},\ + { d, -c, b, -a},\ +} + +#define DEFINE_DST7_P8_MATRIX(a,b,c,d,e,f,g,h) \ +{\ + { a, b, c, d, e, f, g, h},\ + { c, f, h, e, b, -a, -d, -g},\ + { e, g, b, -c, -h, -d, a, f},\ + { g, c, -d, -f, a, h, b, -e},\ + { h, -a, -g, b, f, -c, -e, d},\ + { f, -e, -a, g, -d, -b, h, -c},\ + { d, -h, e, -a, -c, g, -f, b},\ + { b, -d, f, -h, g, -e, c, -a},\ +} + +#define DEFINE_DST7_P8_MATRIX_T(a,b,c,d,e,f,g,h) \ +{\ + { a, c, e, g, h, f, d, b,},\ + { b, f, g, c, -a, -e, -h, -d,},\ + { c, h, b, -d, -g, -a, e, f,},\ + { d, e, -c, -f, b, g, -a, -h,},\ + { e, b, -h, a, f, -d, -c, g,},\ + { f, -a, -d, h, -c, -b, g, -e,},\ + { g, -d, a, b, -e, h, -f, c,},\ + { h, -g, f, -e, d, -c, b, -a,},\ +}\ + +#define DEFINE_DST7_P16_MATRIX(a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p) \ +{ \ + { a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p}, \ + { c, f, i, l, o, o, l, i, f, c, 0, -c, -f, -i, -l, -o}, \ + { e, j, o, m, h, c, -b, -g, -l, -p, -k, -f, -a, d, i, n}, \ + { g, n, l, e, -b, -i, -p, -j, -c, d, k, o, h, a, -f, -m}, \ + { i, o, f, -c, -l, -l, -c, f, o, i, 0, -i, -o, -f, c, l}, \ + { k, k, 0, -k, -k, 0, k, k, 0, -k, -k, 0, k, k, 0, -k}, \ + { m, g, -f, -n, -a, l, h, -e, -o, -b, k, i, -d, -p, -c, j}, \ + { o, c, -l, -f, i, i, -f, -l, c, o, 0, -o, -c, l, f, -i}, \ + { p, -a, -o, b, n, -c, -m, d, l, -e, -k, f, j, -g, -i, h}, \ + { n, -e, -i, j, d, -o, a, m, -f, -h, k, c, -p, b, l, -g}, \ + { l, -i, -c, o, -f, -f, o, 
-c, -i, l, 0, -l, i, c, -o, f}, \ + { j, -m, c, g, -p, f, d, -n, i, a, -k, l, -b, -h, o, -e}, \ + { h, -p, i, -a, -g, o, -j, b, f, -n, k, -c, -e, m, -l, d}, \ + { f, -l, o, -i, c, c, -i, o, -l, f, 0, -f, l, -o, i, -c}, \ + { d, -h, l, -p, m, -i, e, -a, -c, g, -k, o, -n, j, -f, b}, \ + { b, -d, f, -h, j, -l, n, -p, o, -m, k, -i, g, -e, c, -a}, \ +} + +#define DEFINE_DST7_P16_MATRIX_T(a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p) \ +{ \ + {a, c, e, g, i, k, m, o, p, n, l, j, h, f, d, b,},\ + {b, f, j, n, o, k, g, c, -a, -e, -i, -m, -p, -l, -h, -d,},\ + {c, i, o, l, f, 0, -f, -l, -o, -i, -c, c, i, o, l, f,},\ + {d, l, m, e, -c, -k, -n, -f, b, j, o, g, -a, -i, -p, -h,},\ + {e, o, h, -b, -l, -k, -a, i, n, d, -f, -p, -g, c, m, j,},\ + {f, o, c, -i, -l, 0, l, i, -c, -o, -f, f, o, c, -i, -l,},\ + {g, l, -b, -p, -c, k, h, -f, -m, a, o, d, -j, -i, e, n,},\ + {h, i, -g, -j, f, k, -e, -l, d, m, -c, -n, b, o, -a, -p,},\ + {i, f, -l, -c, o, 0, -o, c, l, -f, -i, i, f, -l, -c, o,},\ + {j, c, -p, d, i, -k, -b, o, -e, -h, l, a, -n, f, g, -m,},\ + {k, 0, -k, k, 0, -k, k, 0, -k, k, 0, -k, k, 0, -k, k,},\ + {l, -c, -f, o, -i, 0, i, -o, f, c, -l, l, -c, -f, o, -i,},\ + {m, -f, -a, h, -o, k, -d, -c, j, -p, i, -b, -e, l, -n, g,},\ + {n, -i, d, a, -f, k, -p, l, -g, b, c, -h, m, -o, j, -e,},\ + {o, -l, i, -f, c, 0, -c, f, -i, l, -o, o, -l, i, -f, c,},\ + {p, -o, n, -m, l, -k, j, -i, h, -g, f, -e, d, -c, b, -a,},\ +} + + + +#define DEFINE_DST7_P32_MATRIX(a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,z,A,B,C,D,E,F) \ +{ \ + {a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, A, B, C, D, E, F}, \ + {c, f, i, l, o, r, u, x, A, D, F, C, z, w, t, q, n, k, h, e, b, -a, -d, -g, -j, -m, -p, -s, -v, -y, -B, -E}, \ + {e, j, o, t, y, D, D, y, t, o, j, e, 0, -e, -j, -o, -t, -y, -D, -D, -y, -t, -o, -j, -e, 0, e, j, o, t, y, D}, \ + {g, n, u, B, D, w, p, i, b, -e, -l, -s, -z, -F, -y, -r, -k, -d, c, j, q, x, E, A, t, m, f, -a, -h, -o, -v, -C}, \ + {i, r, A, C, t, k, b, -g, -p, -y, -E, -v, 
-m, -d, e, n, w, F, x, o, f, -c, -l, -u, -D, -z, -q, -h, a, j, s, B}, \ + {k, v, F, u, j, -a, -l, -w, -E, -t, -i, b, m, x, D, s, h, -c, -n, -y, -C, -r, -g, d, o, z, B, q, f, -e, -p, -A}, \ + {m, z, z, m, 0, -m, -z, -z, -m, 0, m, z, z, m, 0, -m, -z, -z, -m, 0, m, z, z, m, 0, -m, -z, -z, -m, 0, m, z}, \ + {o, D, t, e, -j, -y, -y, -j, e, t, D, o, 0, -o, -D, -t, -e, j, y, y, j, -e, -t, -D, -o, 0, o, D, t, e, -j, -y}, \ + {q, E, n, -c, -t, -B, -k, f, w, y, h, -i, -z, -v, -e, l, C, s, b, -o, -F, -p, a, r, D, m, -d, -u, -A, -j, g, x}, \ + {s, A, h, -k, -D, -p, c, v, x, e, -n, -F, -m, f, y, u, b, -q, -C, -j, i, B, r, -a, -t, -z, -g, l, E, o, -d, -w}, \ + {u, w, b, -s, -y, -d, q, A, f, -o, -C, -h, m, E, j, -k, -F, -l, i, D, n, -g, -B, -p, e, z, r, -c, -x, -t, a, v}, \ + {w, s, -d, -A, -o, h, E, k, -l, -D, -g, p, z, c, -t, -v, a, x, r, -e, -B, -n, i, F, j, -m, -C, -f, q, y, b, -u}, \ + {y, o, -j, -D, -e, t, t, -e, -D, -j, o, y, 0, -y, -o, j, D, e, -t, -t, e, D, j, -o, -y, 0, y, o, -j, -D, -e, t}, \ + {A, k, -p, -v, e, F, f, -u, -q, j, B, a, -z, -l, o, w, -d, -E, -g, t, r, -i, -C, -b, y, m, -n, -x, c, D, h, -s}, \ + {C, g, -v, -n, o, u, -h, -B, a, D, f, -w, -m, p, t, -i, -A, b, E, e, -x, -l, q, s, -j, -z, c, F, d, -y, -k, r}, \ + {E, c, -B, -f, y, i, -v, -l, s, o, -p, -r, m, u, -j, -x, g, A, -d, -D, a, F, b, -C, -e, z, h, -w, -k, t, n, -q}, \ + {F, -a, -E, b, D, -c, -C, d, B, -e, -A, f, z, -g, -y, h, x, -i, -w, j, v, -k, -u, l, t, -m, -s, n, r, -o, -q, p}, \ + {D, -e, -y, j, t, -o, -o, t, j, -y, -e, D, 0, -D, e, y, -j, -t, o, o, -t, -j, y, e, -D, 0, D, -e, -y, j, t, -o}, \ + {B, -i, -s, r, j, -A, -a, C, -h, -t, q, k, -z, -b, D, -g, -u, p, l, -y, -c, E, -f, -v, o, m, -x, -d, F, -e, -w, n}, \ + {z, -m, -m, z, 0, -z, m, m, -z, 0, z, -m, -m, z, 0, -z, m, m, -z, 0, z, -m, -m, z, 0, -z, m, m, -z, 0, z, -m}, \ + {x, -q, -g, E, -j, -n, A, -c, -u, t, d, -B, m, k, -D, f, r, -w, -a, y, -p, -h, F, -i, -o, z, -b, -v, s, e, -C, l}, \ + {v, -u, -a, w, -t, -b, x, -s, -c, y, -r, -d, z, -q, 
-e, A, -p, -f, B, -o, -g, C, -n, -h, D, -m, -i, E, -l, -j, F, -k}, \ + {t, -y, e, o, -D, j, j, -D, o, e, -y, t, 0, -t, y, -e, -o, D, -j, -j, D, -o, -e, y, -t, 0, t, -y, e, o, -D, j}, \ + {r, -C, k, g, -y, v, -d, -n, F, -o, -c, u, -z, h, j, -B, s, -a, -q, D, -l, -f, x, -w, e, m, -E, p, b, -t, A, -i}, \ + {p, -F, q, -a, -o, E, -r, b, n, -D, s, -c, -m, C, -t, d, l, -B, u, -e, -k, A, -v, f, j, -z, w, -g, -i, y, -x, h}, \ + {n, -B, w, -i, -e, s, -F, r, -d, -j, x, -A, m, a, -o, C, -v, h, f, -t, E, -q, c, k, -y, z, -l, -b, p, -D, u, -g}, \ + {l, -x, C, -q, e, g, -s, E, -v, j, b, -n, z, -A, o, -c, -i, u, -F, t, -h, -d, p, -B, y, -m, a, k, -w, D, -r, f}, \ + {j, -t, D, -y, o, -e, -e, o, -y, D, -t, j, 0, -j, t, -D, y, -o, e, e, -o, y, -D, t, -j, 0, j, -t, D, -y, o, -e}, \ + {h, -p, x, -F, y, -q, i, -a, -g, o, -w, E, -z, r, -j, b, f, -n, v, -D, A, -s, k, -c, -e, m, -u, C, -B, t, -l, d}, \ + {f, -l, r, -x, D, -C, w, -q, k, -e, -a, g, -m, s, -y, E, -B, v, -p, j, -d, -b, h, -n, t, -z, F, -A, u, -o, i, -c}, \ + {d, -h, l, -p, t, -x, B, -F, C, -y, u, -q, m, -i, e, -a, -c, g, -k, o, -s, w, -A, E, -D, z, -v, r, -n, j, -f, b}, \ + {b, -d, f, -h, j, -l, n, -p, r, -t, v, -x, z, -B, D, -F, E, -C, A, -y, w, -u, s, -q, o, -m, k, -i, g, -e, c, -a}, \ +} + +#define DEFINE_DST7_P32_MATRIX_T(a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,z,A,B,C,D,E,F) \ +{ \ + {a, c, e, g, i, k, m, o, q, s, u, w, y, A, C, E, F, D, B, z, x, v, t, r, p, n, l, j, h, f, d, b,},\ + {b, f, j, n, r, v, z, D, E, A, w, s, o, k, g, c, -a, -e, -i, -m, -q, -u, -y, -C, -F, -B, -x, -t, -p, -l, -h, -d,},\ + {c, i, o, u, A, F, z, t, n, h, b, -d, -j, -p, -v, -B, -E, -y, -s, -m, -g, -a, e, k, q, w, C, D, x, r, l, f,},\ + {d, l, t, B, C, u, m, e, -c, -k, -s, -A, -D, -v, -n, -f, b, j, r, z, E, w, o, g, -a, -i, -q, -y, -F, -x, -p, -h,},\ + {e, o, y, D, t, j, 0, -j, -t, -D, -y, -o, -e, e, o, y, D, t, j, 0, -j, -t, -D, -y, -o, -e, e, o, y, D, t, j,},\ + {f, r, D, w, k, -a, -m, -y, -B, -p, -d, h, t, F, u, i, -c, -o, -A, -z, -n, 
-b, j, v, E, s, g, -e, -q, -C, -x, -l,},\ + {g, u, D, p, b, -l, -z, -y, -k, c, q, E, t, f, -h, -v, -C, -o, -a, m, A, x, j, -d, -r, -F, -s, -e, i, w, B, n,},\ + {h, x, y, i, -g, -w, -z, -j, f, v, A, k, -e, -u, -B, -l, d, t, C, m, -c, -s, -D, -n, b, r, E, o, -a, -q, -F, -p,},\ + {i, A, t, b, -p, -E, -m, e, w, x, f, -l, -D, -q, a, s, B, j, -h, -z, -u, -c, o, F, n, -d, -v, -y, -g, k, C, r,},\ + {j, D, o, -e, -y, -t, 0, t, y, e, -o, -D, -j, j, D, o, -e, -y, -t, 0, t, y, e, -o, -D, -j, j, D, o, -e, -y, -t,},\ + {k, F, j, -l, -E, -i, m, D, h, -n, -C, -g, o, B, f, -p, -A, -e, q, z, d, -r, -y, -c, s, x, b, -t, -w, -a, u, v,},\ + {l, C, e, -s, -v, b, z, o, -i, -F, -h, p, y, a, -w, -r, f, D, k, -m, -B, -d, t, u, -c, -A, -n, j, E, g, -q, -x,},\ + {m, z, 0, -z, -m, m, z, 0, -z, -m, m, z, 0, -z, -m, m, z, 0, -z, -m, m, z, 0, -z, -m, m, z, 0, -z, -m, m, z,},\ + {n, w, -e, -F, -d, x, m, -o, -v, f, E, c, -y, -l, p, u, -g, -D, -b, z, k, -q, -t, h, C, a, -A, -j, r, s, -i, -B,},\ + {o, t, -j, -y, e, D, 0, -D, -e, y, j, -t, -o, o, t, -j, -y, e, D, 0, -D, -e, y, j, -t, -o, o, t, -j, -y, e, D,},\ + {p, q, -o, -r, n, s, -m, -t, l, u, -k, -v, j, w, -i, -x, h, y, -g, -z, f, A, -e, -B, d, C, -c, -D, b, E, -a, -F,},\ + {q, n, -t, -k, w, h, -z, -e, C, b, -F, a, D, -d, -A, g, x, -j, -u, m, r, -p, -o, s, l, -v, -i, y, f, -B, -c, E,},\ + {r, k, -y, -d, F, -c, -z, j, s, -q, -l, x, e, -E, b, A, -i, -t, p, m, -w, -f, D, -a, -B, h, u, -o, -n, v, g, -C,},\ + {s, h, -D, c, x, -n, -m, y, b, -C, i, r, -t, -g, E, -d, -w, o, l, -z, -a, B, -j, -q, u, f, -F, e, v, -p, -k, A,},\ + {t, e, -D, j, o, -y, 0, y, -o, -j, D, -e, -t, t, e, -D, j, o, -y, 0, y, -o, -j, D, -e, -t, t, e, -D, j, o, -y,},\ + {u, b, -y, q, f, -C, m, j, -F, i, n, -B, e, r, -x, a, v, -t, -c, z, -p, -g, D, -l, -k, E, -h, -o, A, -d, -s, w,},\ + {v, -a, -t, x, -c, -r, z, -e, -p, B, -g, -n, D, -i, -l, F, -k, -j, E, -m, -h, C, -o, -f, A, -q, -d, y, -s, -b, w, -u,},\ + {w, -d, -o, E, -l, -g, z, -t, a, r, -B, i, j, -C, q, b, -u, y, -f, -m, F, -n, 
-e, x, -v, c, p, -D, k, h, -A, s,},\ + {x, -g, -j, A, -u, d, m, -D, r, -a, -p, F, -o, -b, s, -C, l, e, -v, z, -i, -h, y, -w, f, k, -B, t, -c, -n, E, -q,},\ + {y, -j, -e, t, -D, o, 0, -o, D, -t, e, j, -y, y, -j, -e, t, -D, o, 0, -o, D, -t, e, j, -y, y, -j, -e, t, -D, o,},\ + {z, -m, 0, m, -z, z, -m, 0, m, -z, z, -m, 0, m, -z, z, -m, 0, m, -z, z, -m, 0, m, -z, z, -m, 0, m, -z, z, -m,},\ + {A, -p, e, f, -q, B, -z, o, -d, -g, r, -C, y, -n, c, h, -s, D, -x, m, -b, -i, t, -E, w, -l, a, j, -u, F, -v, k,},\ + {B, -s, j, -a, -h, q, -z, D, -u, l, -c, -f, o, -x, F, -w, n, -e, -d, m, -v, E, -y, p, -g, -b, k, -t, C, -A, r, -i,},\ + {C, -v, o, -h, a, f, -m, t, -A, E, -x, q, -j, c, d, -k, r, -y, F, -z, s, -l, e, b, -i, p, -w, D, -B, u, -n, g,},\ + {D, -y, t, -o, j, -e, 0, e, -j, o, -t, y, -D, D, -y, t, -o, j, -e, 0, e, -j, o, -t, y, -D, D, -y, t, -o, j, -e,},\ + {E, -B, y, -v, s, -p, m, -j, g, -d, a, b, -e, h, -k, n, -q, t, -w, z, -C, F, -D, A, -x, u, -r, o, -l, i, -f, c,},\ + {F, -E, D, -C, B, -A, z, -y, x, -w, v, -u, t, -s, r, -q, p, -o, n, -m, l, -k, j, -i, h, -g, f, -e, d, -c, b, -a,},\ +} + +// DCT-8 +#define DEFINE_DCT8_P4_MATRIX(a,b,c,d) \ +{ \ + {a, b, c, d}, \ + {b, 0, -b, -b}, \ + {c, -b, -d, a}, \ + {d, -b, a, -c}, \ +} + +#define DEFINE_DCT8_P8_MATRIX(a,b,c,d,e,f,g,h) \ +{ \ + {a, b, c, d, e, f, g, h}, \ + {b, e, h, -g, -d, -a, -c, -f}, \ + {c, h, -e, -a, -f, g, b, d}, \ + {d, -g, -a, -h, c, e, -f, -b}, \ + {e, -d, -f, c, g, -b, -h, a}, \ + {f, -a, g, e, -b, h, d, -c}, \ + {g, -c, b, -f, -h, d, -a, e}, \ + {h, -f, d, -b, a, -c, e, -g}, \ +} + +#define DEFINE_DCT8_P16_MATRIX(a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p) \ +{ \ + {a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p}, \ + {b, e, h, k, n, 0, -n, -k, -h, -e, -b, -b, -e, -h, -k, -n}, \ + {c, h, m, -p, -k, -f, -a, -e, -j, -o, n, i, d, b, g, l}, \ + {d, k, -p, -i, -b, -f, -m, n, g, a, h, o, -l, -e, -c, -j}, \ + {e, n, -k, -b, -h, 0, h, b, k, -n, -e, -e, -n, k, b, h}, \ + {f, 0, -f, -f, 0, f, f, 0, -f, -f, 0, f, f, 0, -f, -f}, 
\ + {g, -n, -a, -m, h, f, -o, -b, -l, i, e, -p, -c, -k, j, d}, \ + {h, -k, -e, n, b, 0, -b, -n, e, k, -h, -h, k, e, -n, -b}, \ + {i, -h, -j, g, k, -f, -l, e, m, -d, -n, c, o, -b, -p, a}, \ + {j, -e, -o, a, -n, -f, i, k, -d, -p, b, -m, -g, h, l, -c}, \ + {k, -b, n, h, -e, 0, e, -h, -n, b, -k, -k, b, -n, -h, e}, \ + {l, -b, i, o, -e, f, -p, -h, c, -m, -k, a, -j, -n, d, -g}, \ + {m, -e, d, -l, -n, f, -c, k, o, -g, b, -j, -p, h, -a, i}, \ + {n, -h, b, -e, k, 0, -k, e, -b, h, -n, -n, h, -b, e, -k}, \ + {o, -k, g, -c, b, -f, j, -n, -p, l, -h, d, -a, e, -i, m}, \ + {p, -n, l, -j, h, -f, d, -b, a, -c, e, -g, i, -k, m, -o}, \ +} + + +#define DEFINE_DCT8_P32_MATRIX(a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,z,A,B,C,D,E,F) \ +{ \ + {a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, A, B, C, D, E, F}, \ + {b, e, h, k, n, q, t, w, z, C, F, -E, -B, -y, -v, -s, -p, -m, -j, -g, -d, -a, -c, -f, -i, -l, -o, -r, -u, -x, -A, -D}, \ + {c, h, m, r, w, B, 0, -B, -w, -r, -m, -h, -c, -c, -h, -m, -r, -w, -B, 0, B, w, r, m, h, c, c, h, m, r, w, B}, \ + {d, k, r, y, F, -A, -t, -m, -f, -b, -i, -p, -w, -D, C, v, o, h, a, g, n, u, B, -E, -x, -q, -j, -c, -e, -l, -s, -z}, \ + {e, n, w, F, -y, -p, -g, -c, -l, -u, -D, A, r, i, a, j, s, B, -C, -t, -k, -b, -h, -q, -z, E, v, m, d, f, o, x}, \ + {f, q, B, -A, -p, -e, -g, -r, -C, z, o, d, h, s, D, -y, -n, -c, -i, -t, -E, x, m, b, j, u, F, -w, -l, -a, -k, -v}, \ + {g, t, 0, -t, -g, -g, -t, 0, t, g, g, t, 0, -t, -g, -g, -t, 0, t, g, g, t, 0, -t, -g, -g, -t, 0, t, g, g, t}, \ + {h, w, -B, -m, -c, -r, 0, r, c, m, B, -w, -h, -h, -w, B, m, c, r, 0, -r, -c, -m, -B, w, h, h, w, -B, -m, -c, -r}, \ + {i, z, -w, -f, -l, -C, t, c, o, F, -q, -a, -r, E, n, d, u, -B, -k, -g, -x, y, h, j, A, -v, -e, -m, -D, s, b, p}, \ + {j, C, -r, -b, -u, z, g, m, F, -o, -e, -x, w, d, p, -E, -l, -h, -A, t, a, s, -B, -i, -k, -D, q, c, v, -y, -f, -n}, \ + {k, F, -m, -i, -D, o, g, B, -q, -e, -z, s, c, x, -u, -a, -v, w, b, t, -y, -d, -r, A, f, p, -C, -h, 
-n, E, j, l}, \ + {l, -E, -h, -p, A, d, t, -w, -a, -x, s, e, B, -o, -i, -F, k, m, -D, -g, -q, z, c, u, -v, -b, -y, r, f, C, -n, -j}, \ + {m, -B, -c, -w, r, h, 0, -h, -r, w, c, B, -m, -m, B, c, w, -r, -h, 0, h, r, -w, -c, -B, m, m, -B, -c, -w, r, h}, \ + {n, -y, -c, -D, i, s, -t, -h, E, d, x, -o, -m, z, b, C, -j, -r, u, g, -F, -e, -w, p, l, -A, -a, -B, k, q, -v, -f}, \ + {o, -v, -h, C, a, D, -g, -w, n, p, -u, -i, B, b, E, -f, -x, m, q, -t, -j, A, c, F, -e, -y, l, r, -s, -k, z, d}, \ + {p, -s, -m, v, j, -y, -g, B, d, -E, -a, -F, c, C, -f, -z, i, w, -l, -t, o, q, -r, -n, u, k, -x, -h, A, e, -D, -b}, \ + {q, -p, -r, o, s, -n, -t, m, u, -l, -v, k, w, -j, -x, i, y, -h, -z, g, A, -f, -B, e, C, -d, -D, c, E, -b, -F, a}, \ + {r, -m, -w, h, B, -c, 0, c, -B, -h, w, m, -r, -r, m, w, -h, -B, c, 0, -c, B, h, -w, -m, r, r, -m, -w, h, B, -c}, \ + {s, -j, -B, a, -C, -i, t, r, -k, -A, b, -D, -h, u, q, -l, -z, c, -E, -g, v, p, -m, -y, d, -F, -f, w, o, -n, -x, e}, \ + {t, -g, 0, g, -t, -t, g, 0, -g, t, t, -g, 0, g, -t, -t, g, 0, -g, t, t, -g, 0, g, -t, -t, g, 0, -g, t, t, -g}, \ + {u, -d, B, n, -k, -E, g, -r, -x, a, -y, -q, h, -F, -j, o, A, -c, v, t, -e, C, m, -l, -D, f, -s, -w, b, -z, -p, i}, \ + {v, -a, w, u, -b, x, t, -c, y, s, -d, z, r, -e, A, q, -f, B, p, -g, C, o, -h, D, n, -i, E, m, -j, F, l, -k}, \ + {w, -c, r, B, -h, m, 0, -m, h, -B, -r, c, -w, -w, c, -r, -B, h, -m, 0, m, -h, B, r, -c, w, w, -c, r, B, -h, m}, \ + {x, -f, m, -E, -q, b, -t, -B, j, -i, A, u, -c, p, F, -n, e, -w, -y, g, -l, D, r, -a, s, C, -k, h, -z, -v, d, -o}, \ + {y, -i, h, -x, -z, j, -g, w, A, -k, f, -v, -B, l, -e, u, C, -m, d, -t, -D, n, -c, s, E, -o, b, -r, -F, p, -a, q}, \ + {z, -l, c, -q, E, u, -g, h, -v, -D, p, -b, m, -A, -y, k, -d, r, -F, -t, f, -i, w, C, -o, a, -n, B, x, -j, e, -s}, \ + {A, -o, c, -j, v, F, -t, h, -e, q, -C, -y, m, -a, l, -x, -D, r, -f, g, -s, E, w, -k, b, -n, z, B, -p, d, -i, u}, \ + {B, -r, h, -c, m, -w, 0, w, -m, c, -h, r, -B, -B, r, -h, c, -m, w, 0, -w, m, -c, h, -r, B, B, -r, h, 
-c, m, -w}, \ + {C, -u, m, -e, d, -l, t, -B, -D, v, -n, f, -c, k, -s, A, E, -w, o, -g, b, -j, r, -z, -F, x, -p, h, -a, i, -q, y}, \ + {D, -x, r, -l, f, -a, g, -m, s, -y, E, C, -w, q, -k, e, -b, h, -n, t, -z, F, B, -v, p, -j, d, -c, i, -o, u, -A}, \ + {E, -A, w, -s, o, -k, g, -c, b, -f, j, -n, r, -v, z, -D, -F, B, -x, t, -p, l, -h, d, -a, e, -i, m, -q, u, -y, C}, \ + {F, -D, B, -z, x, -v, t, -r, p, -n, l, -j, h, -f, d, -b, a, -c, e, -g, i, -k, m, -o, q, -s, u, -w, y, -A, C, -E}, \ +} + + +// DST-7 +ALIGNED(64) const int16_t kvz_g_dst7_4[4][4] = DEFINE_DST7_P4_MATRIX(29, 55, 74, 84); +ALIGNED(64) const int16_t kvz_g_dst7_8[8][8] = DEFINE_DST7_P8_MATRIX(17, 32, 46, 60, 71, 78, 85, 86); +ALIGNED(64) const int16_t kvz_g_dst7_16[16][16] = DEFINE_DST7_P16_MATRIX(8, 17, 25, 33, 40, 48, 55, 62, 68, 73, 77, 81, 85, 87, 88, 88); +ALIGNED(64) const int16_t kvz_g_dst7_32[32][32] = DEFINE_DST7_P32_MATRIX(4, 9, 13, 17, 21, 26, 30, 34, 38, 42, 46, 50, 53, 56, 60, 63, 66, 68, 72, 74, 77, 78, 80, 82, 84, 85, 86, 87, 88, 89, 90, 90); + +ALIGNED(64) const int16_t kvz_g_dst7_4_t[4][4] = DEFINE_DST7_P4_MATRIX_T(29, 55, 74, 84); +ALIGNED(64) const int16_t kvz_g_dst7_8_t[8][8] = DEFINE_DST7_P8_MATRIX_T(17, 32, 46, 60, 71, 78, 85, 86); +ALIGNED(64) const int16_t kvz_g_dst7_16_t[16][16] = DEFINE_DST7_P16_MATRIX_T(8, 17, 25, 33, 40, 48, 55, 62, 68, 73, 77, 81, 85, 87, 88, 88); +ALIGNED(64) const int16_t kvz_g_dst7_32_t[32][32] = DEFINE_DST7_P32_MATRIX_T(4, 9, 13, 17, 21, 26, 30, 34, 38, 42, 46, 50, 53, 56, 60, 63, 66, 68, 72, 74, 77, 78, 80, 82, 84, 85, 86, 87, 88, 89, 90, 90); + +// DCT-8 +ALIGNED(64) const int16_t kvz_g_dct8_4[4][4] = DEFINE_DCT8_P4_MATRIX(84, 74, 55, 29); +ALIGNED(64) const int16_t kvz_g_dct8_8[8][8] = DEFINE_DCT8_P8_MATRIX(86, 85, 78, 71, 60, 46, 32, 17); +ALIGNED(64) const int16_t kvz_g_dct8_16[16][16] = DEFINE_DCT8_P16_MATRIX(88, 88, 87, 85, 81, 77, 73, 68, 62, 55, 48, 40, 33, 25, 17, 8); +ALIGNED(64) const int16_t kvz_g_dct8_32[32][32] = DEFINE_DCT8_P32_MATRIX(90, 90, 
89, 88, 87, 86, 85, 84, 82, 80, 78, 77, 74, 72, 68, 66, 63, 60, 56, 53, 50, 46, 42, 38, 34, 30, 26, 21, 17, 13, 9, 4); + +const int16_t* kvz_g_mts_input[2][3][5] = { + { + {&kvz_g_dct_4[0][0], &kvz_g_dct_8[0][0], &kvz_g_dct_16[0][0], &kvz_g_dct_32[0][0], NULL}, + {&kvz_g_dct8_4[0][0], &kvz_g_dct8_8[0][0], &kvz_g_dct8_16[0][0], &kvz_g_dct8_32[0][0], NULL}, + {&kvz_g_dst7_4[0][0], &kvz_g_dst7_8[0][0], &kvz_g_dst7_16[0][0], &kvz_g_dst7_32[0][0], NULL} + }, + { + {&kvz_g_dct_4_t[0][0], &kvz_g_dct_8_t[0][0], &kvz_g_dct_16_t[0][0], &kvz_g_dct_32_t[0][0], NULL}, + {&kvz_g_dct8_4[0][0], &kvz_g_dct8_8[0][0], &kvz_g_dct8_16[0][0], &kvz_g_dct8_32[0][0], NULL}, + {&kvz_g_dst7_4_t[0][0], &kvz_g_dst7_8_t[0][0], &kvz_g_dst7_16_t[0][0], &kvz_g_dst7_32_t[0][0], NULL} + }, +}; + +static void mts_dct_4x4_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth) +{ + const int height = 4; + const int width = 4; + + const int log2_width_minus2 = kvz_g_convert_to_bit[width]; + + const int32_t shift_1st = log2_width_minus2 + bitdepth - 7; + const int32_t shift_2nd = log2_width_minus2 + 8; + + const int16_t* tdct = kvz_g_mts_input[1][type_hor][0]; + const int16_t* dct = kvz_g_mts_input[0][type_ver][0]; + + __m256i tdct_v = _mm256_load_si256((const __m256i*) tdct); + __m256i dct_v = _mm256_load_si256((const __m256i*) dct); + __m256i in_v = _mm256_load_si256((const __m256i*)input); + + __m256i tmp = mul_clip_matrix_4x4_avx2(in_v, tdct_v, shift_1st); + __m256i result = mul_clip_matrix_4x4_avx2(dct_v, tmp, shift_2nd); + + _mm256_store_si256((__m256i*)output, result); +} + +static void mts_idct_4x4_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth) +{ + int32_t shift_1st = 7; + int32_t shift_2nd = 12 - (bitdepth - 8); + + const int16_t* tdct = kvz_g_mts_input[1][type_ver][0]; + const int16_t* dct = kvz_g_mts_input[0][type_hor][0]; + + __m256i tdct_v = _mm256_load_si256((const __m256i*)tdct); + 
__m256i dct_v = _mm256_load_si256((const __m256i*) dct); + __m256i in_v = _mm256_load_si256((const __m256i*)input); + + __m256i tmp = mul_clip_matrix_4x4_avx2(tdct_v, in_v, shift_1st); + __m256i result = mul_clip_matrix_4x4_avx2(tmp, dct_v, shift_2nd); + + _mm256_store_si256((__m256i*)output, result); +} + +static void mts_dct_8x8_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth) +{ + int32_t shift_1st = kvz_g_convert_to_bit[8] + 1 + (bitdepth - 8); + int32_t shift_2nd = kvz_g_convert_to_bit[8] + 8; + + const int16_t* dct1 = kvz_g_mts_input[0][type_hor][1]; + const int16_t* dct2 = kvz_g_mts_input[0][type_ver][1]; + + __m256i tmpres[4]; + + matmul_8x8_a_bt_t(input, dct1, tmpres, shift_1st); + matmul_8x8_a_bt(dct2, tmpres, output, shift_2nd); +} + +static void mts_idct_8x8_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth) +{ + int32_t shift_1st = 7; + int32_t shift_2nd = 12 - (bitdepth - 8); + ALIGNED(64) int16_t tmp[8 * 8]; + + const int16_t* tdct = kvz_g_mts_input[1][type_ver][1]; + const int16_t* dct = kvz_g_mts_input[0][type_hor][1]; + + mul_clip_matrix_8x8_avx2(tdct, input, tmp, shift_1st); + mul_clip_matrix_8x8_avx2(tmp, dct, output, shift_2nd); +} + + +static void mts_dct_16x16_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth) +{ + int32_t shift_1st = kvz_g_convert_to_bit[16] + 1 + (bitdepth - 8); + int32_t shift_2nd = kvz_g_convert_to_bit[16] + 8; + + const int16_t* dct1 = kvz_g_mts_input[0][type_hor][2]; + const int16_t* dct2 = kvz_g_mts_input[0][type_ver][2]; + + /* + * Multiply input by the transpose of DCT matrix into tmpres, and DCT matrix + * by tmpres - this is then our output matrix + * + * It's easier to implement an AVX2 matrix multiplication if you can multiply + * the left term with the transpose of the right term.
Here things are stored + * row-wise, not column-wise, so we can effectively read DCT_T column-wise + * into YMM registers by reading DCT row-wise. Also because of this, the + * first multiplication is hacked to produce the transpose of the result + * instead, since it will be used in similar fashion as the right operand + * in the second multiplication. + */ + + const __m256i* d_v = (const __m256i*)dct1; + const __m256i* d_v2 = (const __m256i*)dct2; + const __m256i* i_v = (const __m256i*)input; + __m256i* o_v = (__m256i*)output; + __m256i tmp[16]; + + // Hack! (A * B^T)^T = B * A^T, so we can dispatch the transpose-producing + // multiply completely + matmul_16x16_a_bt(d_v, i_v, tmp, shift_1st); + matmul_16x16_a_bt(d_v2, tmp, o_v, shift_2nd); +} + +static void partial_butterfly_inverse_16_mts_avx2(const int16_t* src, int16_t* dst, int32_t shift, tr_type_t type) +{ + __m256i tsrc[16]; + + const uint32_t width = 16; + + const int16_t* tdct = kvz_g_mts_input[1][type][2]; + + const __m256i eo_signmask = _mm256_setr_epi32(1, 1, 1, 1, -1, -1, -1, -1); + const __m256i eeo_signmask = _mm256_setr_epi32(1, 1, -1, -1, -1, -1, 1, 1); + const __m256i o_signmask = _mm256_set1_epi32(-1); + + const __m256i final_shufmask = _mm256_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, + 6, 7, 4, 5, 2, 3, 0, 1, + 14, 15, 12, 13, 10, 11, 8, 9); + transpose_16x16(src, (int16_t*)tsrc); + + const __m256i dct_cols[8] = { + _mm256_load_si256((const __m256i*)tdct + 0), + _mm256_load_si256((const __m256i*)tdct + 1), + _mm256_load_si256((const __m256i*)tdct + 2), + _mm256_load_si256((const __m256i*)tdct + 3), + _mm256_load_si256((const __m256i*)tdct + 4), + _mm256_load_si256((const __m256i*)tdct + 5), + _mm256_load_si256((const __m256i*)tdct + 6), + _mm256_load_si256((const __m256i*)tdct + 7), + }; + + // These contain: D1,0 D3,0 D5,0 D7,0 D9,0 Db,0 Dd,0 Df,0 | D1,4 D3,4 D5,4 D7,4 D9,4 Db,4 Dd,4 Df,4 + // D1,1 D3,1 D5,1 D7,1 D9,1 Db,1 Dd,1 Df,1 | D1,5 D3,5 D5,5 D7,5 D9,5 Db,5
Dd,5 Df,5 + // D1,2 D3,2 D5,2 D7,2 D9,2 Db,2 Dd,2 Df,2 | D1,6 D3,6 D5,6 D7,6 D9,6 Db,6 Dd,6 Df,6 + // D1,3 D3,3 D5,3 D7,3 D9,3 Db,3 Dd,3 Df,3 | D1,7 D3,7 D5,7 D7,7 D9,7 Db,7 Dd,7 Df,7 + __m256i dct_col_odds[4]; + for (uint32_t j = 0; j < 4; j++) { + dct_col_odds[j] = extract_combine_odds(dct_cols[j + 0], dct_cols[j + 4]); + } + for (uint32_t j = 0; j < width; j++) { + __m256i col = tsrc[j]; + __m256i odds = extract_odds(col); + + __m256i o04 = _mm256_madd_epi16(odds, dct_col_odds[0]); + __m256i o15 = _mm256_madd_epi16(odds, dct_col_odds[1]); + __m256i o26 = _mm256_madd_epi16(odds, dct_col_odds[2]); + __m256i o37 = _mm256_madd_epi16(odds, dct_col_odds[3]); + + __m256i o0145 = _mm256_hadd_epi32(o04, o15); + __m256i o2367 = _mm256_hadd_epi32(o26, o37); + + __m256i o = _mm256_hadd_epi32(o0145, o2367); + + // D0,2 D0,6 D1,2 D1,6 D1,a D1,e D0,a D0,e | D2,2 D2,6 D3,2 D3,6 D3,a D3,e D2,a D2,e + __m256i d_db2 = extract_26ae(dct_cols); + + // 2 6 2 6 a e a e | 2 6 2 6 a e a e + __m256i t_db2 = extract_26ae_vec(col); + + __m256i eo_parts = _mm256_madd_epi16(d_db2, t_db2); + __m256i eo_parts2 = _mm256_shuffle_epi32(eo_parts, _MM_SHUFFLE(0, 1, 2, 3)); + + // EO0 EO1 EO1 EO0 | EO2 EO3 EO3 EO2 + __m256i eo = _mm256_add_epi32(eo_parts, eo_parts2); + __m256i eo2 = _mm256_permute4x64_epi64(eo, _MM_SHUFFLE(1, 3, 2, 0)); + __m256i eo3 = _mm256_sign_epi32(eo2, eo_signmask); + + __m256i d_db4 = extract_d048c(dct_cols); + __m256i t_db4 = extract_d048c_vec(col); + __m256i eee_eeo = _mm256_madd_epi16(d_db4, t_db4); + + __m256i eee_eee = _mm256_permute4x64_epi64(eee_eeo, _MM_SHUFFLE(3, 0, 3, 0)); + __m256i eeo_eeo1 = _mm256_permute4x64_epi64(eee_eeo, _MM_SHUFFLE(1, 2, 1, 2)); + + __m256i eeo_eeo2 = _mm256_sign_epi32(eeo_eeo1, eeo_signmask); + + // EE0 EE1 EE2 EE3 | EE3 EE2 EE1 EE0 + __m256i ee = _mm256_add_epi32(eee_eee, eeo_eeo2); + __m256i e = _mm256_add_epi32(ee, eo3); + + __m256i o_neg = _mm256_sign_epi32(o, o_signmask); + __m256i o_lo = _mm256_blend_epi32(o, o_neg, 0xf0); // 1111 0000 
+ __m256i o_hi = _mm256_blend_epi32(o, o_neg, 0x0f); // 0000 1111 + + __m256i res_lo = _mm256_add_epi32(e, o_lo); + __m256i res_hi = _mm256_add_epi32(e, o_hi); + __m256i res_hi2 = _mm256_permute4x64_epi64(res_hi, _MM_SHUFFLE(1, 0, 3, 2)); + + __m256i res_lo_t = truncate_inv(res_lo, shift); + __m256i res_hi_t = truncate_inv(res_hi2, shift); + + __m256i res_16_1 = _mm256_packs_epi32(res_lo_t, res_hi_t); + __m256i final = _mm256_shuffle_epi8(res_16_1, final_shufmask); + + _mm256_store_si256((__m256i*)dst + j, final); + } +} + +static void mts_idct_16x16_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth) +{ + int32_t shift_1st = 7; + int32_t shift_2nd = 12 - (bitdepth - 8); + ALIGNED(64) int16_t tmp[16 * 16]; + + partial_butterfly_inverse_16_mts_avx2(input, tmp, shift_1st, type_ver); + partial_butterfly_inverse_16_mts_avx2(tmp, output, shift_2nd, type_hor); +} + +// 32x32 matrix multiplication with value clipping. +// Parameters: Two 32x32 matrices containing 16-bit values in consecutive addresses, +// destination for the result and the shift value for clipping. 
+static void mul_clip_matrix_32x32_mts_avx2(const int16_t* left, + const int16_t* right, + int16_t* dst, + const int32_t shift, int skip_line, int skip_line2) +{ + const int32_t add = 1 << (shift - 1); + const __m256i debias = _mm256_set1_epi32(add); + + const int reduced_line = 32 - skip_line; + const int cutoff = 32 - skip_line2; + + const uint32_t* l_32 = (const uint32_t*)left; + const __m256i* r_v = (const __m256i*)right; + __m256i* dst_v = (__m256i*)dst; + + __m256i accu[128] = { _mm256_setzero_si256() }; + size_t i, j; + + for (j = 0; j < 64; j += 4) { + const __m256i r0 = r_v[j + 0]; + const __m256i r1 = r_v[j + 1]; + const __m256i r2 = r_v[j + 2]; + const __m256i r3 = r_v[j + 3]; + + __m256i r02l = _mm256_unpacklo_epi16(r0, r2); + __m256i r02h = _mm256_unpackhi_epi16(r0, r2); + __m256i r13l = _mm256_unpacklo_epi16(r1, r3); + __m256i r13h = _mm256_unpackhi_epi16(r1, r3); + + __m256i r02_07 = _mm256_permute2x128_si256(r02l, r02h, 0x20); + __m256i r02_8f = _mm256_permute2x128_si256(r02l, r02h, 0x31); + + __m256i r13_07 = _mm256_permute2x128_si256(r13l, r13h, 0x20); + __m256i r13_8f = _mm256_permute2x128_si256(r13l, r13h, 0x31); + + for (i = 0; i < 32; i += 2) { + size_t acc_base = i << 2; + + uint32_t curr_e = l_32[(i + 0) * (32 / 2) + (j >> 2)]; + uint32_t curr_o = l_32[(i + 1) * (32 / 2) + (j >> 2)]; + + __m256i even = _mm256_set1_epi32(curr_e); + __m256i odd = _mm256_set1_epi32(curr_o); + + __m256i p_e0 = _mm256_madd_epi16(even, r02_07); + __m256i p_e1 = _mm256_madd_epi16(even, r02_8f); + __m256i p_e2 = _mm256_madd_epi16(even, r13_07); + __m256i p_e3 = _mm256_madd_epi16(even, r13_8f); + + __m256i p_o0 = _mm256_madd_epi16(odd, r02_07); + __m256i p_o1 = _mm256_madd_epi16(odd, r02_8f); + __m256i p_o2 = _mm256_madd_epi16(odd, r13_07); + __m256i p_o3 = _mm256_madd_epi16(odd, r13_8f); + + accu[acc_base + 0] = _mm256_add_epi32(p_e0, accu[acc_base + 0]); + accu[acc_base + 1] = _mm256_add_epi32(p_e1, accu[acc_base + 1]); + accu[acc_base + 2] = _mm256_add_epi32(p_e2, 
accu[acc_base + 2]); + accu[acc_base + 3] = _mm256_add_epi32(p_e3, accu[acc_base + 3]); + + accu[acc_base + 4] = _mm256_add_epi32(p_o0, accu[acc_base + 4]); + accu[acc_base + 5] = _mm256_add_epi32(p_o1, accu[acc_base + 5]); + accu[acc_base + 6] = _mm256_add_epi32(p_o2, accu[acc_base + 6]); + accu[acc_base + 7] = _mm256_add_epi32(p_o3, accu[acc_base + 7]); + } + } + + for (i = 0; i < 32; i++) { + size_t acc_base = i << 2; + size_t dst_base = i << 1; + + __m256i q0 = truncate_avx2(accu[acc_base + 0], debias, shift); + __m256i q1 = truncate_avx2(accu[acc_base + 1], debias, shift); + __m256i q2 = truncate_avx2(accu[acc_base + 2], debias, shift); + __m256i q3 = truncate_avx2(accu[acc_base + 3], debias, shift); + + __m256i h01 = _mm256_packs_epi32(q0, q1); + __m256i h23 = _mm256_packs_epi32(q2, q3); + + h01 = _mm256_permute4x64_epi64(h01, _MM_SHUFFLE(3, 1, 2, 0)); + h23 = _mm256_permute4x64_epi64(h23, _MM_SHUFFLE(3, 1, 2, 0)); + + _mm256_store_si256(dst_v + dst_base + 0, h01); + _mm256_store_si256(dst_v + dst_base + 1, h23); + } + + if (skip_line) + { + int16_t* dst2 = dst + reduced_line; + for (j = 0; j < cutoff; j++) + { + memset(dst2, 0, sizeof(int16_t) * skip_line); + dst2 += 32; + } + } + + if (skip_line2) + { + int16_t* dst2 = dst + 32 * cutoff; + memset(dst2, 0, sizeof(int16_t) * 32 * skip_line2); + } +} + +static void mts_dct_32x32_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth) +{ + int32_t shift_1st = kvz_g_convert_to_bit[32] + 1 + (bitdepth - 8); + int32_t shift_2nd = kvz_g_convert_to_bit[32] + 8; + ALIGNED(64) int16_t tmp[32 * 32]; + + const int16_t* tdct = kvz_g_mts_input[1][type_hor][3]; + const int16_t* dct = kvz_g_mts_input[0][type_ver][3]; + + const int skip_width = (type_hor != DCT2) ? 16 : 0; + const int skip_height = (type_ver != DCT2) ? 
16 : 0; + + mul_clip_matrix_32x32_mts_avx2(input, tdct, tmp, shift_1st, skip_width, 0 ); + mul_clip_matrix_32x32_mts_avx2(dct, tmp, output, shift_2nd, skip_width, skip_height); +} + + +/* Inverse 32x32 MTS transform: tmp = tdct * input, output = tmp * dct, with tdct / dct looked up in kvz_g_mts_input by type_ver / type_hor. The skip arguments of the kernel zero parts of its RESULT. That is only valid for the columns of tmp that belong to zeroed-out horizontal coefficients (skip_width, first pass). The rows of tmp and every sample of output are spatial data and must not be cleared, so no row skip is passed and the second pass skips nothing. Assumes the caller supplies coefficients whose high-frequency rows are already zero for a non-DCT2 type_ver -- TODO confirm against the generic implementation. */ static void mts_idct_32x32_avx2(const int16_t* input, int16_t* output, tr_type_t type_hor, tr_type_t type_ver, uint8_t bitdepth) +{ + int32_t shift_1st = 7; + int32_t shift_2nd = 12 - (bitdepth - 8); + ALIGNED(64) int16_t tmp[32 * 32]; + const int16_t* tdct = kvz_g_mts_input[1][type_ver][3]; + const int16_t* dct = kvz_g_mts_input[0][type_hor][3]; + + const int skip_width = (type_hor != DCT2) ? 16 : 0; + + mul_clip_matrix_32x32_mts_avx2(tdct, input, tmp, shift_1st, skip_width, 0); + mul_clip_matrix_32x32_mts_avx2(tmp, dct, output, shift_2nd, 0, 0); +} + +/* Common signature of the size-specific MTS kernels: (input, output, type_hor, type_ver, bitdepth). */ typedef void tr_func(const int16_t*, int16_t*, tr_type_t , tr_type_t , uint8_t); + +// ToDo: Enable MTS 2x2 and 64x64 transforms +/* Indexed by kvz_g_convert_to_bit[width]; entry 4 is NULL and must never be called through. */ static tr_func* dct_table[5] = { + mts_dct_4x4_avx2, mts_dct_8x8_avx2, mts_dct_16x16_avx2, mts_dct_32x32_avx2, NULL +}; + +/* Inverse counterpart of dct_table, same indexing, entry 4 is NULL as well. */ static tr_func* idct_table[5] = { + mts_idct_4x4_avx2, mts_idct_8x8_avx2, mts_idct_16x16_avx2, mts_idct_32x32_avx2, NULL/*fastInverseDCT2_B64*/ +}; + + +/* Defined in the generic DCT strategy file; writes the horizontal and vertical transform type to hor_out / ver_out. */ extern void kvz_get_tr_type( + int8_t width, + color_t color, + const cu_info_t* tu, + tr_type_t* hor_out, + tr_type_t* ver_out, + const int8_t mts_idx); + +/* "mts_dct" strategy entry point: asks kvz_get_tr_type() for the transform types; when both are DCT2 it forwards to the function returned by kvz_get_dct_func(), otherwise to dct_table[kvz_g_convert_to_bit[width]]. */ static void mts_dct_avx2( + const int8_t bitdepth, + const color_t color, + const cu_info_t* tu, + const int8_t width, + const int16_t* input, + int16_t* output, + const int8_t mts_idx) +{ + tr_type_t type_hor; + tr_type_t type_ver; + + kvz_get_tr_type(width, color, tu, &type_hor, &type_ver, mts_idx); + + if (type_hor == DCT2 && type_ver == DCT2) + { + dct_func* dct_func = kvz_get_dct_func(width, color, tu->type); + dct_func(bitdepth, input, output); + } + else + { + const int log2_width_minus2 = kvz_g_convert_to_bit[width]; + + tr_func* dct = dct_table[log2_width_minus2]; + + dct(input, output, type_hor, 
type_ver, bitdepth); + } +} + + +/* "mts_idct" strategy entry point: same dispatch as mts_dct_avx2, using kvz_get_idct_func() for the pure DCT2 case and idct_table otherwise. */ static void mts_idct_avx2( + const int8_t bitdepth, + const color_t color, + const cu_info_t* tu, + const int8_t width, + const int16_t* input, + int16_t* output, + const int8_t mts_idx) +{ + tr_type_t type_hor; + tr_type_t type_ver; + + kvz_get_tr_type(width, color, tu, &type_hor, &type_ver, mts_idx); + + if (type_hor == DCT2 && type_ver == DCT2) + { + dct_func* idct_func = kvz_get_idct_func(width, color, tu->type); + idct_func(bitdepth, input, output); + } + else + { + const int log2_width_minus2 = kvz_g_convert_to_bit[width]; + + tr_func* idct = idct_table[log2_width_minus2]; + + idct(input, output, type_hor, type_ver, bitdepth); + } +} + + int kvz_strategy_register_dct_avx2(void* opaque, uint8_t bitdepth) { bool success = true; @@ -948,6 +1643,10 @@ int kvz_strategy_register_dct_avx2(void* opaque, uint8_t bitdepth) success &= kvz_strategyselector_register(opaque, "idct_8x8", "avx2", 40, &matrix_idct_8x8_avx2); success &= kvz_strategyselector_register(opaque, "idct_16x16", "avx2", 40, &matrix_idct_16x16_avx2); success &= kvz_strategyselector_register(opaque, "idct_32x32", "avx2", 40, &matrix_idct_32x32_avx2); + + success &= kvz_strategyselector_register(opaque, "mts_dct", "avx2", 40, &mts_dct_avx2); + success &= kvz_strategyselector_register(opaque, "mts_idct", "avx2", 40, &mts_idct_avx2); + } #endif // KVZ_BIT_DEPTH == 8 #endif //COMPILE_INTEL_AVX2 diff --git a/src/strategies/generic/dct-generic.c b/src/strategies/generic/dct-generic.c index 59f460c4..d25af16b 100644 --- a/src/strategies/generic/dct-generic.c +++ b/src/strategies/generic/dct-generic.c @@ -2413,19 +2413,11 @@ static partial_tr_func* idct_table[3][5] = { }; -//MTS transform tags -typedef enum tr_type_t { - DCT2 = 0, - DCT8 = 1, - DST7 = 2, - NUM_TRANS_TYPE = 3, - DCT2_MTS = 4 -} tr_type_t; +const tr_type_t mts_subset_intra[4][2] = { { DST7, DST7 }, { DCT8, DST7 }, { DST7, DCT8 }, { DCT8, DCT8 } }; +const tr_type_t mts_subset_inter[2] = { DCT8, DST7 }; -static 
const tr_type_t mts_subset_intra[4][2] = { { DST7, DST7 }, { DCT8, DST7 }, { DST7, DCT8 }, { DCT8, DCT8 } }; -static const tr_type_t mts_subset_inter[2] = { DCT8, DST7 }; -static INLINE void get_tr_type( +/* Formerly the file-local get_tr_type(); given external linkage so that other strategy files can call it. */ void kvz_get_tr_type( int8_t width, color_t color, const cu_info_t* tu, @@ -2485,7 +2477,7 @@ static void mts_dct_generic( tr_type_t type_hor; tr_type_t type_ver; - get_tr_type(width, color, tu, &type_hor, &type_ver, mts_idx); + kvz_get_tr_type(width, color, tu, &type_hor, &type_ver, mts_idx); if (type_hor == DCT2 && type_ver == DCT2) { @@ -2524,7 +2516,7 @@ static void mts_idct_generic( tr_type_t type_hor; tr_type_t type_ver; - get_tr_type(width, color, tu, &type_hor, &type_ver, mts_idx); + kvz_get_tr_type(width, color, tu, &type_hor, &type_ver, mts_idx); if (type_hor == DCT2 && type_ver == DCT2) { diff --git a/src/strategies/strategies-dct.h b/src/strategies/strategies-dct.h index b75be31f..7e30c4eb 100644 --- a/src/strategies/strategies-dct.h +++ b/src/strategies/strategies-dct.h @@ -47,24 +47,28 @@ extern dct_func * kvz_idct_8x8; extern dct_func * kvz_idct_16x16; extern dct_func * kvz_idct_32x32; +/* Signature of a forward MTS transform strategy; kvz_mts_dct points to the selected one. */ typedef void (mts_dct_func)( + int8_t bitdepth, + color_t color, + const cu_info_t* tu, + int8_t width, + const int16_t* input, + int16_t* output, + const int8_t mts_idx); -extern void(*kvz_mts_dct)( +extern mts_dct_func* kvz_mts_dct; + +/* Signature of an inverse MTS transform strategy; same parameter list as mts_dct_func. */ typedef void (mts_idct_func)( int8_t bitdepth, color_t color, - const cu_info_t *tu, + const cu_info_t* tu, int8_t width, - const int16_t *input, - int16_t *output, - const int8_t mts_idx); -extern void(*kvz_mts_idct)( - int8_t bitdepth, - color_t color, - const cu_info_t *tu, - int8_t width, - const int16_t *input, - int16_t *output, + const int16_t* input, + int16_t* output, const int8_t mts_idx); +extern mts_idct_func* kvz_mts_idct; + int kvz_strategy_register_dct(void* opaque, uint8_t bitdepth); dct_func * kvz_get_dct_func(int8_t width, color_t color, cu_type_t type); dct_func * kvz_get_idct_func(int8_t width, color_t color, 
cu_type_t type); diff --git a/tests/Makefile.am b/tests/Makefile.am index 8bc63a11..5ff06563 100644 --- a/tests/Makefile.am +++ b/tests/Makefile.am @@ -23,6 +23,7 @@ check_PROGRAMS = kvazaar_tests kvazaar_tests_SOURCES = \ coeff_sum_tests.c \ dct_tests.c \ + mts_tests.c \ intra_sad_tests.c \ mv_cand_tests.c \ sad_tests.c \ diff --git a/tests/mts_tests.c b/tests/mts_tests.c new file mode 100644 index 00000000..ac68055a --- /dev/null +++ b/tests/mts_tests.c @@ -0,0 +1,221 @@ +/***************************************************************************** + * This file is part of Kvazaar HEVC encoder. + * + * Copyright (C) 2013-2015 Tampere University of Technology and others (see + * COPYING file). + * + * Kvazaar is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License version 2.1 as + * published by the Free Software Foundation. + * + * Kvazaar is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with Kvazaar. If not, see <http://www.gnu.org/licenses/>. + ****************************************************************************/ + +#include "greatest/greatest.h" + +#include "test_strategies.h" + +#include "src/image.h" + +#include <math.h> /* sqrt() */ +#include <stdlib.h> /* malloc(), free() */ + + +////////////////////////////////////////////////////////////////////////// +// MACROS +#define NUM_SIZES 4 /* block widths 1 << LCU_MIN_LOG_W .. 1 << LCU_MAX_LOG_W */ +#define LCU_MAX_LOG_W 5 +#define LCU_MIN_LOG_W 2 +#define NUM_TRANSFORM 4 /* tr_idx values MTS_DST7_DST7 .. MTS_DST7_DST7 + 3 */ +#define NUM_TESTS (NUM_TRANSFORM*NUM_SIZES) /* parenthesized so the product survives use inside larger expressions */ + +////////////////////////////////////////////////////////////////////////// +// GLOBALS +static int16_t * dct_bufs[NUM_TESTS] = { 0 }; // SIMD aligned pointers. +static int16_t * dct_actual_bufs[NUM_TESTS] = { 0 }; // pointers returned by malloc. 
+ +/* Reference outputs computed by the "generic" strategies in setup_tests(), indexed [transform][block size]. */ static int16_t dct_result[NUM_TRANSFORM][NUM_SIZES][LCU_WIDTH*LCU_WIDTH] = { { 0 } }; +static int16_t idct_result[NUM_TRANSFORM][NUM_SIZES][LCU_WIDTH*LCU_WIDTH] = { { 0 } }; + +/* State handed from the suite loop to the TEST bodies. tested_func holds either an mts_dct or an mts_idct strategy function. */ static struct test_env_t { + int log_width; // for selecting dim from bufs + mts_dct_func* tested_func; + const strategy_t * strategy; + char msg[1024]; +} test_env; + + +////////////////////////////////////////////////////////////////////////// +// SETUP, TEARDOWN AND HELPER FUNCTIONS +/* Fills a width x width block: each sample is slope times the Euclidean distance from (x_px, y_px), rounded by adding 0.5 and passed through CLIP(0, 255, val). */ static void init_gradient(int x_px, int y_px, int width, int slope, int16_t *buf) +{ + for (int y = 0; y < width; ++y) { + for (int x = 0; x < width; ++x) { + int diff_x = x_px - x; + int diff_y = y_px - y; + int val = slope * sqrt(diff_x * diff_x + diff_y * diff_y) + 0.5; + buf[y * width + x] = CLIP(0, 255, val); + } + } +} + + +/* Allocates NUM_TESTS over-sized buffers, derives SIMD-aligned pointers from them, fills each with an LCU_WIDTH gradient and then records the reference results of the "generic" strategies. NOTE(review): the malloc() results are not checked. */ static void setup_tests() +{ + for (int test = 0; test < NUM_TESTS; ++test) { + + dct_actual_bufs[test] = malloc(LCU_WIDTH*LCU_WIDTH*sizeof(int16_t) + SIMD_ALIGNMENT); + dct_bufs[test] = ALIGNED_POINTER(dct_actual_bufs[test], SIMD_ALIGNMENT); + } + + for (int test = 0; test < NUM_TESTS; ++test) { + const int width = LCU_WIDTH; + init_gradient(width, width, width, 255 / width, dct_bufs[test]); + } + + + + // Compute the forward reference results with the "generic" mts_dct strategy. 
+ int block = 0; + for (int s = 0; s < strategies.count; ++s) + { + strategy_t *strat = &strategies.strategies[s]; + mts_dct_func* mts_generic = 0; + if (strcmp(strat->type, "mts_dct") == 0 && + strcmp(strat->strategy_name, "generic") == 0) + { + mts_generic = strat->fptr; + for (block = 0; block < NUM_SIZES; block++) { + for (int trafo = 0; trafo < NUM_TRANSFORM; trafo++) { + /* NOTE(review): only type and tr_idx of tu are initialised; presumably the only fields the transform reads -- confirm. */ cu_info_t tu; + tu.type = CU_INTRA; + tu.tr_idx = MTS_DST7_DST7 + trafo; + mts_generic(KVZ_BIT_DEPTH, COLOR_Y, &tu, 1 << (LCU_MIN_LOG_W + block), dct_bufs[trafo*NUM_SIZES+block], dct_result[trafo][block], KVZ_MTS_BOTH); + } + } + } + } + + /* Same for the inverse direction: the "generic" mts_idct strategy is run on the same gradient buffers and fills idct_result. */ block = 0; + for (int s = 0; s < strategies.count; ++s) + { + strategy_t *strat = &strategies.strategies[s]; + mts_idct_func* idct_generic = 0; + if (strcmp(strat->type, "mts_idct") == 0 && + strcmp(strat->strategy_name, "generic") == 0) + { + + idct_generic = strat->fptr; + for (block = 0; block < NUM_SIZES; block++) { + for (int trafo = 0; trafo < NUM_TRANSFORM; trafo++) { + cu_info_t tu; + tu.type = CU_INTRA; + tu.tr_idx = MTS_DST7_DST7 + trafo; + idct_generic(KVZ_BIT_DEPTH, COLOR_Y, &tu, 1 << (LCU_MIN_LOG_W + block), dct_bufs[trafo * NUM_SIZES + block], idct_result[trafo][block], KVZ_MTS_BOTH); + } + } + + } + } +} + +/* Releases the buffers allocated by setup_tests(). */ static void tear_down_tests() +{ + for (int test = 0; test < NUM_TESTS; ++test) { + free(dct_actual_bufs[test]); + } +} + + +////////////////////////////////////////////////////////////////////////// +// TESTS +/* Runs test_env.tested_func on every (block size, transform) pair and requires all LCU_WIDTH * LCU_WIDTH entries of its zero-initialised output buffer to equal the reference in dct_result. */ TEST dct(void) +{ + char testname[100]; + for (int blocksize = 0; blocksize < NUM_SIZES; blocksize++) { + for (int trafo = 0; trafo < NUM_TRANSFORM; trafo++) { + sprintf(testname, "Block: %d x %d, trafo: %d", 1 << (LCU_MIN_LOG_W + blocksize), 1 << (LCU_MIN_LOG_W + blocksize), trafo); + cu_info_t tu; + tu.type = CU_INTRA; + tu.tr_idx = MTS_DST7_DST7 + trafo; + + int16_t* buf = dct_bufs[trafo * NUM_SIZES + blocksize]; + ALIGNED(32) int16_t test_result[LCU_WIDTH * LCU_WIDTH] = { 0 }; + + test_env.tested_func(KVZ_BIT_DEPTH, COLOR_Y, 
&tu, 1 << (LCU_MIN_LOG_W + blocksize), buf, test_result, KVZ_MTS_BOTH); + + for (int i = 0; i < LCU_WIDTH * LCU_WIDTH; ++i) { + ASSERT_EQm(testname, test_result[i], dct_result[trafo][blocksize][i]); + } + fprintf(stderr, "PASS: %s\r\n", testname); + } + } + + PASS(); +} + +/* Inverse-transform counterpart of the dct test: runs test_env.tested_func on every (block size, transform) pair and compares the whole output buffer with the reference in idct_result. */ TEST idct(void) +{ + char testname[100]; + for (int blocksize = 0; blocksize < NUM_SIZES; blocksize++) { + for (int trafo = 0; trafo < NUM_TRANSFORM; trafo++) { + sprintf(testname, "Block: %d x %d, trafo: %d", 1 << (LCU_MIN_LOG_W + blocksize), 1 << (LCU_MIN_LOG_W + blocksize), trafo); + cu_info_t tu; + tu.type = CU_INTRA; + tu.tr_idx = MTS_DST7_DST7 + trafo; + + int16_t* buf = dct_bufs[trafo * NUM_SIZES + blocksize]; + ALIGNED(32) int16_t test_result[LCU_WIDTH * LCU_WIDTH] = { 0 }; + + test_env.tested_func(KVZ_BIT_DEPTH, COLOR_Y, &tu, 1 << (LCU_MIN_LOG_W + blocksize), buf, test_result, KVZ_MTS_BOTH); + + for (int i = 0; i < LCU_WIDTH * LCU_WIDTH; ++i) { + ASSERT_EQm(testname, test_result[i], idct_result[trafo][blocksize][i]); + } + fprintf(stderr, "PASS: %s\r\n", testname); + } + } + + PASS(); +} + + +////////////////////////////////////////////////////////////////////////// +// TEST FIXTURES +/* Suite entry: builds the buffers and reference results, runs the dct / idct test once per registered mts_dct / mts_idct strategy, then frees the buffers. */ SUITE(mts_tests) +{ + //SET_SETUP(sad_setup); + //SET_TEARDOWN(sad_teardown); + + setup_tests(); + + // Loop through all strategies, picking out the mts_dct and mts_idct ones, + // and run the matching test on each of them + for (volatile unsigned i = 0; i < strategies.count; ++i) { + const strategy_t * strategy = &strategies.strategies[i]; + + test_env.tested_func = strategies.strategies[i].fptr; + test_env.strategy = strategy; + + // Call different tests depending on type of function. + // This allows for selecting a subset of tests with -t parameter. 
+ if (strcmp(strategy->type, "mts_dct") == 0) + { + fprintf(stderr, "Test: %s\r\n", strategy->strategy_name); + RUN_TEST(dct); + } + else if (strcmp(strategy->type, "mts_idct") == 0) + { + fprintf(stderr, "Test: %s\r\n", strategy->strategy_name); + RUN_TEST(idct); + } + } + + tear_down_tests(); +} diff --git a/tests/tests_main.c b/tests/tests_main.c index 74fde444..1d8f01a2 100644 --- a/tests/tests_main.c +++ b/tests/tests_main.c @@ -28,6 +28,7 @@ extern SUITE(intra_sad_tests); extern SUITE(satd_tests); extern SUITE(speed_tests); extern SUITE(dct_tests); +extern SUITE(mts_tests); #endif //KVZ_BIT_DEPTH == 8 extern SUITE(coeff_sum_tests); @@ -44,6 +45,7 @@ int main(int argc, char **argv) RUN_SUITE(intra_sad_tests); RUN_SUITE(satd_tests); RUN_SUITE(dct_tests); + RUN_SUITE(mts_tests); if (greatest_info.suite_filter && greatest_name_match("speed", greatest_info.suite_filter))