[isp] Fix transform selection when MTS & ISP is used. Wrong transform was selected. Change mts parameter name to better reflect its purpose.

This commit is contained in:
siivonek 2022-09-26 17:52:34 +03:00 committed by Marko Viitanen
parent 85f6b00394
commit b4cc321349
4 changed files with 28 additions and 18 deletions

View file

@ -1584,7 +1584,7 @@ extern void uvg_get_tr_type(
const cu_info_t* tu, const cu_info_t* tu,
tr_type_t* hor_out, tr_type_t* hor_out,
tr_type_t* ver_out, tr_type_t* ver_out,
const int8_t mts_idx); const int8_t mts_type);
static void mts_dct_avx2( static void mts_dct_avx2(
const int8_t bitdepth, const int8_t bitdepth,
@ -1594,12 +1594,12 @@ static void mts_dct_avx2(
const int8_t height, const int8_t height,
const int16_t* input, const int16_t* input,
int16_t* output, int16_t* output,
const int8_t mts_idx) const int8_t mts_type)
{ {
tr_type_t type_hor; tr_type_t type_hor;
tr_type_t type_ver; tr_type_t type_ver;
uvg_get_tr_type(width, height, color, tu, &type_hor, &type_ver, mts_idx); uvg_get_tr_type(width, height, color, tu, &type_hor, &type_ver, mts_type);
if (type_hor == DCT2 && type_ver == DCT2 && !tu->lfnst_idx && width == height) if (type_hor == DCT2 && type_ver == DCT2 && !tu->lfnst_idx && width == height)
{ {
@ -1625,12 +1625,12 @@ static void mts_idct_avx2(
const int8_t height, const int8_t height,
const int16_t* input, const int16_t* input,
int16_t* output, int16_t* output,
const int8_t mts_idx) const int8_t mts_type)
{ {
tr_type_t type_hor; tr_type_t type_hor;
tr_type_t type_ver; tr_type_t type_ver;
uvg_get_tr_type(width, height, color, tu, &type_hor, &type_ver, mts_idx); uvg_get_tr_type(width, height, color, tu, &type_hor, &type_ver, mts_type);
if (type_hor == DCT2 && type_ver == DCT2 && width == height) if (type_hor == DCT2 && type_ver == DCT2 && width == height)
{ {

View file

@ -2505,7 +2505,7 @@ void uvg_get_tr_type(
const cu_info_t* tu, const cu_info_t* tu,
tr_type_t* hor_out, tr_type_t* hor_out,
tr_type_t* ver_out, tr_type_t* ver_out,
const int8_t mts_idx) const int8_t mts_type)
{ {
*hor_out = DCT2; *hor_out = DCT2;
*ver_out = DCT2; *ver_out = DCT2;
@ -2515,14 +2515,20 @@ void uvg_get_tr_type(
return; return;
} }
const bool explicit_mts = mts_idx == UVG_MTS_BOTH || (tu->type == CU_INTRA ? mts_idx == UVG_MTS_INTRA : (mts_idx == UVG_MTS_INTER && tu->type == CU_INTER)); const bool explicit_mts = mts_type == UVG_MTS_BOTH || (tu->type == CU_INTRA ? mts_type == UVG_MTS_INTRA : (mts_type == UVG_MTS_INTER && tu->type == CU_INTER));
const bool implicit_mts = tu->type == CU_INTRA && (mts_idx == UVG_MTS_IMPLICIT || mts_idx == UVG_MTS_INTER); const bool implicit_mts = tu->type == CU_INTRA && (mts_type == UVG_MTS_IMPLICIT || mts_type == UVG_MTS_INTER);
assert(!(explicit_mts && implicit_mts)); assert(!(explicit_mts && implicit_mts));
const bool is_isp = tu->type == CU_INTRA && tu->intra.isp_mode && color == COLOR_Y ? tu->intra.isp_mode : 0;
const int8_t lfnst_idx = color == COLOR_Y ? tu->lfnst_idx : tu->cr_lfnst_idx;
// const bool is_sbt = cu->type == CU_INTER && tu->sbt && color == COLOR_Y; // TODO: check SBT here when implemented
if (implicit_mts) if (is_isp && lfnst_idx) {
return;
}
if (implicit_mts || (is_isp && explicit_mts))
{ {
// ISP_TODO: do these apply for ISP blocks?
bool width_ok = width >= 4 && width <= 16; bool width_ok = width >= 4 && width <= 16;
bool height_ok = height >= 4 && height <= 16; bool height_ok = height >= 4 && height <= 16;
@ -2537,6 +2543,10 @@ void uvg_get_tr_type(
return; return;
} }
/*
TODO: SBT HANDLING
*/
if (explicit_mts) if (explicit_mts)
{ {
if (tu->tr_idx > MTS_SKIP) { if (tu->tr_idx > MTS_SKIP) {
@ -2555,12 +2565,12 @@ static void mts_dct_generic(
const int8_t height, const int8_t height,
const int16_t* input, const int16_t* input,
int16_t* output, int16_t* output,
const int8_t mts_idx) const int8_t mts_type)
{ {
tr_type_t type_hor; tr_type_t type_hor;
tr_type_t type_ver; tr_type_t type_ver;
uvg_get_tr_type(width, height, color, tu, &type_hor, &type_ver, mts_idx); uvg_get_tr_type(width, height, color, tu, &type_hor, &type_ver, mts_type);
if (type_hor == DCT2 && type_ver == DCT2 && !tu->lfnst_idx && !tu->cr_lfnst_idx && width == height) if (type_hor == DCT2 && type_ver == DCT2 && !tu->lfnst_idx && !tu->cr_lfnst_idx && width == height)
{ {
@ -2610,12 +2620,12 @@ static void mts_idct_generic(
const int8_t height, const int8_t height,
const int16_t* input, const int16_t* input,
int16_t* output, int16_t* output,
const int8_t mts_idx) const int8_t mts_type)
{ {
tr_type_t type_hor; tr_type_t type_hor;
tr_type_t type_ver; tr_type_t type_ver;
uvg_get_tr_type(width, height, color, tu, &type_hor, &type_ver, mts_idx); uvg_get_tr_type(width, height, color, tu, &type_hor, &type_ver, mts_type);
if (type_hor == DCT2 && type_ver == DCT2 && !tu->lfnst_idx && !tu->cr_lfnst_idx && width == height) if (type_hor == DCT2 && type_ver == DCT2 && !tu->lfnst_idx && !tu->cr_lfnst_idx && width == height)
{ {

View file

@ -60,7 +60,7 @@ void(*uvg_mts_dct)(int8_t bitdepth,
int8_t height, int8_t height,
const int16_t *input, const int16_t *input,
int16_t *output, int16_t *output,
const int8_t mts_idx); const int8_t mts_type);
void(*uvg_mts_idct)(int8_t bitdepth, void(*uvg_mts_idct)(int8_t bitdepth,
color_t color, color_t color,
@ -69,7 +69,7 @@ void(*uvg_mts_idct)(int8_t bitdepth,
int8_t height, int8_t height,
const int16_t *input, const int16_t *input,
int16_t *output, int16_t *output,
const int8_t mts_idx); const int8_t mts_type);
int uvg_strategy_register_dct(void* opaque, uint8_t bitdepth) { int uvg_strategy_register_dct(void* opaque, uint8_t bitdepth) {

View file

@ -68,7 +68,7 @@ typedef void (mts_dct_func)(
int8_t height, int8_t height,
const int16_t* input, const int16_t* input,
int16_t* output, int16_t* output,
const int8_t mts_idx); const int8_t mts_type);
extern mts_dct_func* uvg_mts_dct; extern mts_dct_func* uvg_mts_dct;
@ -80,7 +80,7 @@ typedef void (mts_idct_func)(
int8_t height, int8_t height,
const int16_t* input, const int16_t* input,
int16_t* output, int16_t* output,
const int8_t mts_idx); const int8_t mts_type);
extern mts_idct_func* uvg_mts_idct; extern mts_idct_func* uvg_mts_idct;