[mip] Implement MIP search.

This commit is contained in:
siivonek 2022-01-20 00:11:50 +02:00
parent e672f9b24a
commit 59a86f339e
6 changed files with 194 additions and 39 deletions

View file

@ -169,6 +169,8 @@ typedef struct
int8_t mode;
int8_t mode_chroma;
uint8_t multi_ref_idx;
bool mip_flag;
bool mip_is_transposed;
} intra;
struct {
mv_t mv[2][2]; // \brief Motion vectors for L0 and L1

View file

@ -546,14 +546,14 @@ void kvz_predict_cclm(
}
void kvz_mip_boundary_downsampling(kvz_pixel* reduced_dst, const kvz_pixel* const ref_src, int src_len, int dst_len)
void kvz_mip_boundary_downsampling_1D(kvz_pixel* reduced_dst, const kvz_pixel* const ref_src, int src_len, int dst_len)
{
if (dst_len < src_len)
{
// Create reduced boundary by downsampling
uint16_t down_smp_factor = src_len / dst_len;
// Calculate floor log2. TODO: find a better / faster solution
// Calculate floor log2. MIP_TODO: find a better / faster solution
int tmp = 0;
if (down_smp_factor & 0xffff0000) {
down_smp_factor >>= 16;
@ -614,7 +614,7 @@ void kvz_mip_reduced_pred(kvz_pixel* const output,
const int input_size = 2 * red_bdry_size;
// Use local buffer for transposed result
kvz_pixel* out_buf_transposed = MALLOC(kvz_pixel, red_pred_size * red_pred_size); // TODO: get rid of MALLOC & FREE
kvz_pixel* out_buf_transposed = MALLOC(kvz_pixel, red_pred_size * red_pred_size); // MIP_TODO: get rid of MALLOC & FREE
kvz_pixel* const out_ptr = transpose ? out_buf_transposed : output;
int sum = 0;
@ -669,7 +669,7 @@ void kvz_mip_pred_upsampling_1D(kvz_pixel* const dst, const kvz_pixel* const src
const uint16_t boundary_step,
const uint16_t ups_factor)
{
// Calculate floor log2. TODO: find a better / faster solution
// Calculate floor log2. MIP_TODO: find a better / faster solution
uint16_t upsample_factor = ups_factor;
int tmp = 0;
if (upsample_factor & 0xffff0000) {
@ -737,15 +737,16 @@ void kvz_mip_pred_upsampling_1D(kvz_pixel* const dst, const kvz_pixel* const src
/** \brief Matrix weighted intra prediction.
*/
void kvz_mip_predict(encoder_state_t const* const state,
kvz_intra_references* const refs,
const uint16_t pred_block_width,
const uint16_t pred_block_height)
void kvz_mip_predict(encoder_state_t const* const state, kvz_intra_references* const refs,
const uint16_t pred_block_width, const uint16_t pred_block_height,
const color_t color,
kvz_pixel* dst,
const int mip_mode, const bool mip_transp)
{
// Separate this function into smaller bits if needed
kvz_pixel* result; // TODO: pass the dst buffer to this function
const int mode_idx = 0; // TODO: pass mode
kvz_pixel* result = dst;
const int mode_idx = mip_mode;
// *** INPUT PREP ***
@ -790,8 +791,8 @@ void kvz_mip_predict(encoder_state_t const* const state,
kvz_pixel* const top_reduced = &red_bdry[0];
kvz_pixel* const left_reduced = &red_bdry[red_bdry_size];
kvz_mip_boundary_downsampling(top_reduced, ref_samples_top, width, red_bdry_size);
kvz_mip_boundary_downsampling(left_reduced, ref_samples_left, height, red_bdry_size);
kvz_mip_boundary_downsampling_1D(top_reduced, ref_samples_top, width, red_bdry_size);
kvz_mip_boundary_downsampling_1D(left_reduced, ref_samples_left, height, red_bdry_size);
// Transposed reduced boundaries
kvz_pixel* const left_reduced_trans = &red_bdry_trans[0];
@ -822,7 +823,7 @@ void kvz_mip_predict(encoder_state_t const* const state,
// *** BLOCK PREDICT ***
const bool need_upsampling = (ups_hor_factor > 1) || (ups_ver_factor > 1);
const bool transpose = 0; // TODO: pass transpose
const bool transpose = mip_transp;
uint8_t* matrix;
switch (size_id) {
@ -839,7 +840,7 @@ void kvz_mip_predict(encoder_state_t const* const state,
assert(false && "Invalid MIP size id.");
}
kvz_pixel* red_pred_buffer = MALLOC(kvz_pixel, red_pred_size * red_pred_size); // TODO: get rid of MALLOC and FREE
kvz_pixel* red_pred_buffer = MALLOC(kvz_pixel, red_pred_size * red_pred_size); // MIP_TODO: get rid of MALLOC and FREE
kvz_pixel* const reduced_pred = need_upsampling ? red_pred_buffer : result;
const kvz_pixel* const reduced_bdry = transpose ? red_bdry_trans : red_bdry;
@ -1357,7 +1358,9 @@ static void intra_recon_tb_leaf(
cclm_parameters_t *cclm_params,
lcu_t *lcu,
color_t color,
uint8_t multi_ref_idx)
uint8_t multi_ref_idx,
bool use_mip,
bool mip_transp)
{
const kvz_config *cfg = &state->encoder_control->cfg;
const int shift = color == COLOR_Y ? 0 : 1;
@ -1368,6 +1371,7 @@ static void intra_recon_tb_leaf(
log2width -= 1;
}
const int width = 1 << log2width;
const int height = width; // TODO: proper height for non-square blocks
const int lcu_width = LCU_WIDTH >> shift;
const vector2d_t luma_px = { x, y };
@ -1404,7 +1408,13 @@ static void intra_recon_tb_leaf(
int stride = state->tile->frame->source->stride;
const bool filter_boundary = color == COLOR_Y && !(cfg->lossless && cfg->implicit_rdpcm);
if(intra_mode < 68) {
kvz_intra_predict(state, &refs, log2width, intra_mode, color, pred, filter_boundary, multi_ref_index);
if (use_mip) {
assert(intra_mode < 16 && "MIP mode must be between [0, 16]");
kvz_mip_predict(state, &refs, width, height, color, pred, intra_mode, mip_transp);
}
else {
kvz_intra_predict(state, &refs, log2width, intra_mode, color, pred, filter_boundary, multi_ref_index);
}
} else {
kvz_pixels_blit(&state->tile->frame->cclm_luma_rec[x / 2 + (y * stride) / 4], pred, width, width, stride / 2, width);
if(cclm_params == NULL) {
@ -1464,6 +1474,8 @@ void kvz_intra_recon_cu(
cu_info_t *cur_cu,
cclm_parameters_t *cclm_params,
uint8_t multi_ref_idx,
bool mip_flag,
bool mip_transp,
lcu_t *lcu)
{
const vector2d_t lcu_px = { SUB_SCU(x), SUB_SCU(y) };
@ -1472,6 +1484,8 @@ void kvz_intra_recon_cu(
cur_cu = LCU_GET_CU_AT_PX(lcu, lcu_px.x, lcu_px.y);
}
uint8_t multi_ref_index = multi_ref_idx;
bool use_mip = mip_flag;
bool mip_transposed = mip_transp;
// Reset CBFs because CBFs might have been set
// for depth earlier
@ -1489,10 +1503,10 @@ void kvz_intra_recon_cu(
const int32_t x2 = x + offset;
const int32_t y2 = y + offset;
kvz_intra_recon_cu(state, x, y, depth + 1, mode_luma, mode_chroma, NULL, NULL, multi_ref_index, lcu);
kvz_intra_recon_cu(state, x2, y, depth + 1, mode_luma, mode_chroma, NULL, NULL, multi_ref_index, lcu);
kvz_intra_recon_cu(state, x, y2, depth + 1, mode_luma, mode_chroma, NULL, NULL, multi_ref_index, lcu);
kvz_intra_recon_cu(state, x2, y2, depth + 1, mode_luma, mode_chroma, NULL, NULL, multi_ref_index, lcu);
kvz_intra_recon_cu(state, x, y, depth + 1, mode_luma, mode_chroma, NULL, NULL, multi_ref_index, use_mip, mip_transposed, lcu);
kvz_intra_recon_cu(state, x2, y, depth + 1, mode_luma, mode_chroma, NULL, NULL, multi_ref_index, use_mip, mip_transposed, lcu);
kvz_intra_recon_cu(state, x, y2, depth + 1, mode_luma, mode_chroma, NULL, NULL, multi_ref_index, use_mip, mip_transposed, lcu);
kvz_intra_recon_cu(state, x2, y2, depth + 1, mode_luma, mode_chroma, NULL, NULL, multi_ref_index, use_mip, mip_transposed, lcu);
// Propagate coded block flags from child CUs to parent CU.
uint16_t child_cbfs[3] = {
@ -1513,11 +1527,11 @@ void kvz_intra_recon_cu(
const bool has_chroma = mode_chroma != -1 && (x % 8 == 0 && y % 8 == 0);
// Process a leaf TU.
if (has_luma) {
intra_recon_tb_leaf(state, x, y, depth, mode_luma, cclm_params, lcu, COLOR_Y, multi_ref_index);
intra_recon_tb_leaf(state, x, y, depth, mode_luma, cclm_params, lcu, COLOR_Y, multi_ref_index, use_mip, mip_transposed);
}
if (has_chroma) {
intra_recon_tb_leaf(state, x, y, depth, mode_chroma, cclm_params, lcu, COLOR_U, 0);
intra_recon_tb_leaf(state, x, y, depth, mode_chroma, cclm_params, lcu, COLOR_V, 0);
intra_recon_tb_leaf(state, x, y, depth, mode_chroma, cclm_params, lcu, COLOR_U, 0, use_mip, mip_transposed);
intra_recon_tb_leaf(state, x, y, depth, mode_chroma, cclm_params, lcu, COLOR_V, 0, use_mip, mip_transposed);
}
kvz_quantize_lcu_residual(state, has_luma, has_chroma, x, y, depth, cur_cu, lcu, false);

View file

@ -130,6 +130,8 @@ void kvz_intra_recon_cu(
cu_info_t *cur_cu,
cclm_parameters_t* cclm_params,
uint8_t multi_ref_idx,
bool mip_flag,
bool mip_transp,
lcu_t *lcu);
@ -146,4 +148,15 @@ void kvz_predict_cclm(
kvz_intra_references* chroma_ref,
kvz_pixel* dst,
cclm_parameters_t* cclm_params
);
/**
 * \brief Matrix weighted intra prediction (MIP).
 *
 * \param state       Encoder state.
 * \param refs        Intra reference samples used to build the prediction.
 * \param width       Width of the block to predict.
 * \param height      Height of the block to predict.
 * \param color       Color component to predict.
 * \param dst         Buffer that receives the predicted block.
 * \param mip_mode    MIP mode index. NOTE(review): the caller in
 *                    intra_recon_tb_leaf asserts this is below 16 - confirm
 *                    the valid range for each block size.
 * \param mip_transp  True if the transposed variant of the mode is used.
 */
void kvz_mip_predict(
encoder_state_t const * const state,
kvz_intra_references * refs,
const uint16_t width,
const uint16_t height,
const color_t color,
kvz_pixel* dst,
const int mip_mode,
const bool mip_transp
);

View file

@ -727,14 +727,18 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth,
int8_t intra_trafo;
double intra_cost;
uint8_t multi_ref_index = 0;
bool mip_flag;
bool mip_transposed;
kvz_search_cu_intra(state, x, y, depth, lcu,
&intra_mode, &intra_trafo, &intra_cost, &multi_ref_index);
&intra_mode, &intra_trafo, &intra_cost, &multi_ref_index, &mip_flag, &mip_transposed);
if (intra_cost < cost) {
cost = intra_cost;
cur_cu->type = CU_INTRA;
cur_cu->part_size = depth > MAX_DEPTH ? SIZE_NxN : SIZE_2Nx2N;
cur_cu->intra.mode = intra_mode;
cur_cu->intra.multi_ref_idx = multi_ref_index;
cur_cu->intra.mip_flag = mip_flag;
cur_cu->intra.mip_is_transposed = mip_transposed;
//If the CU is not split from 64x64 block, the MTS is disabled for that CU.
cur_cu->tr_idx = (depth > 0) ? intra_trafo : 0;
@ -751,7 +755,9 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth,
x, y,
depth,
cur_cu->intra.mode, -1, // skip chroma
NULL, NULL, cur_cu->intra.multi_ref_idx, lcu);
NULL, NULL, cur_cu->intra.multi_ref_idx,
cur_cu->intra.mip_flag, cur_cu->intra.mip_is_transposed,
lcu);
downsample_cclm_rec(
state, x, y, cu_width / 2, cu_width / 2, lcu->rec.y, lcu->left_ref.y[64]
@ -773,7 +779,8 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth,
x & ~7, y & ~7, // TODO: as does this
depth,
-1, cur_cu->intra.mode_chroma, // skip luma
NULL, cclm_params, 0, lcu);
NULL, cclm_params, 0, cur_cu->intra.mip_flag, cur_cu->intra.mip_is_transposed,
lcu);
}
} else if (cur_cu->type == CU_INTER) {
@ -933,7 +940,8 @@ static double search_cu(encoder_state_t * const state, int x, int y, int depth,
x, y,
depth,
cur_cu->intra.mode, mode_chroma,
NULL,NULL, 0, lcu);
NULL,NULL, 0, cur_cu->intra.mip_flag, cur_cu->intra.mip_is_transposed,
lcu);
cost += kvz_cu_rd_cost_luma(state, x_local, y_local, depth, cur_cu, lcu);
if (has_chroma) {

View file

@ -333,7 +333,9 @@ static double search_intra_trdepth(encoder_state_t * const state,
x_px, y_px,
depth,
intra_mode, -1,
pred_cu, cclm_params, pred_cu->intra.multi_ref_idx, lcu);
pred_cu, cclm_params, pred_cu->intra.multi_ref_idx,
pred_cu->intra.mip_flag, pred_cu->intra.mip_is_transposed,
lcu);
// TODO: Not sure if this should be 0 or 1 but at least seems to work with 1
if (pred_cu->tr_idx > 1)
@ -361,7 +363,9 @@ static double search_intra_trdepth(encoder_state_t * const state,
x_px, y_px,
depth,
-1, chroma_mode,
pred_cu, cclm_params, 0, lcu);
pred_cu, cclm_params, 0,
pred_cu->intra.mip_flag, pred_cu->intra.mip_is_transposed,
lcu);
best_rd_cost += kvz_cu_rd_cost_chroma(state, lcu_px.x, lcu_px.y, depth, pred_cu, lcu);
}
pred_cu->tr_skip = best_tr_idx == MTS_SKIP;
@ -715,6 +719,8 @@ static int8_t search_intra_rdo(encoder_state_t * const state,
int8_t *intra_preds,
int modes_to_check,
int8_t modes[67], int8_t trafo[67], double costs[67],
int num_mip_modes,
int8_t mip_modes[32], int8_t mip_trafo[32], double mip_costs[32],
lcu_t *lcu,
uint8_t multi_ref_idx)
{
@ -749,6 +755,47 @@ static int8_t search_intra_rdo(encoder_state_t * const state,
}
}
// MIP_TODO: implement this inside the standard intra for loop. Code duplication is bad.
// MIP search
const int transp_off = num_mip_modes >> 1;
for (int mip_mode = 0; mip_mode < num_mip_modes; ++mip_mode) {
int rdo_bitcost = 0; // MIP_TODO: MIP needs own bit cost calculation
mip_costs[mip_mode] = rdo_bitcost * (int)(state->lambda + 0.5); // MIP_TODO: check if this is also correct in the case when MIP is used.
const bool is_transposed = (mip_modes[mip_mode] >= transp_off ? true : false);
// There can be 32 MIP modes, but only mode numbers [0, 15] are ever written to bitstream.
// Modes [16, 31] are indicated with the separate transpose flag.
int8_t pred_mode = (is_transposed ? mip_modes[mip_mode] - transp_off : mip_modes[mip_mode]);
// Perform transform split search and save mode RD cost for the best one.
cu_info_t pred_cu;
pred_cu.depth = depth;
pred_cu.type = CU_INTRA;
pred_cu.part_size = ((depth == MAX_PU_DEPTH) ? SIZE_NxN : SIZE_2Nx2N); // TODO: non-square blocks
pred_cu.intra.mode = pred_mode;
pred_cu.intra.mode_chroma = pred_mode;
pred_cu.intra.multi_ref_idx = 0;
pred_cu.intra.mip_is_transposed = is_transposed;
pred_cu.intra.mip_flag = true;
pred_cu.joint_cb_cr = 0;
FILL(pred_cu.cbf, 0);
// Reset transform split data in lcu.cu for this area.
kvz_lcu_fill_trdepth(lcu, x_px, y_px, depth, depth);
double mode_cost = search_intra_trdepth(state, x_px, y_px, depth, tr_depth, pred_mode, MAX_INT, &pred_cu, lcu, NULL, -1);
mip_costs[mip_mode] += mode_cost;
mip_trafo[mip_mode] = pred_cu.tr_idx;
// MIP_TODO: check if ET is viable when MIP is used
// Early termination if no coefficients have to be coded
if (state->encoder_control->cfg.intra_rdo_et && !cbf_is_set_any(pred_cu.cbf, depth)) {
modes_to_check = mip_mode + 1;
break;
}
}
for(int rdo_mode = 0; rdo_mode < modes_to_check; rdo_mode ++) {
int rdo_bitcost = kvz_luma_mode_bits(state, modes[rdo_mode], intra_preds, multi_ref_idx);
@ -762,6 +809,8 @@ static int8_t search_intra_rdo(encoder_state_t * const state,
pred_cu.intra.mode = modes[rdo_mode];
pred_cu.intra.mode_chroma = modes[rdo_mode];
pred_cu.intra.multi_ref_idx = multi_ref_idx;
pred_cu.intra.mip_is_transposed = false;
pred_cu.intra.mip_flag = false;
pred_cu.joint_cb_cr = 0;
FILL(pred_cu.cbf, 0);
@ -780,6 +829,9 @@ static int8_t search_intra_rdo(encoder_state_t * const state,
}
// Update order according to new costs
if (num_mip_modes) {
kvz_sort_modes_intra_luma(mip_modes, mip_trafo, mip_costs, num_mip_modes);
}
kvz_sort_modes_intra_luma(modes, trafo, costs, modes_to_check);
// The best transform split hierarchy is not saved anywhere, so to get the
@ -921,7 +973,7 @@ int8_t kvz_search_intra_chroma_rdo(encoder_state_t * const state,
x_px, y_px,
depth,
-1, chroma.mode, // skip luma
NULL, NULL, 0, lcu);
NULL, NULL, 0, false, false, lcu);
}
else {
@ -954,7 +1006,7 @@ int8_t kvz_search_intra_chroma_rdo(encoder_state_t * const state,
x_px, y_px,
depth,
-1, chroma.mode, // skip luma
NULL, cclm_params, 0, lcu);
NULL, cclm_params, 0, false, false, lcu);
}
chroma.cost = kvz_cu_rd_cost_chroma(state, lcu_px.x, lcu_px.y, depth, tr_cu, lcu);
@ -1044,7 +1096,9 @@ void kvz_search_cu_intra(encoder_state_t * const state,
int8_t *mode_out,
int8_t *trafo_out,
double *cost_out,
uint8_t *multi_ref_idx_out)
uint8_t *multi_ref_idx_out,
bool *mip_flag_out,
bool * mip_transposed_out)
{
const vector2d_t lcu_px = { SUB_SCU(x_px), SUB_SCU(y_px) };
const int8_t cu_width = LCU_WIDTH >> depth;
@ -1081,6 +1135,38 @@ void kvz_search_cu_intra(encoder_state_t * const state,
int8_t trafo[MAX_REF_LINE_IDX][67] = { 0 };
double costs[MAX_REF_LINE_IDX][67];
bool enable_mip = state->encoder_control->cfg.mip;
int8_t mip_modes[32]; // Modes [0, 15] are non-transposed. Modes [16,31] are transposed.
int8_t mip_trafo[32];
double mip_costs[32];
if (enable_mip) {
for (int i = 0; i < 32; ++i) {
mip_modes[i] = i;
mip_costs[i] = MAX_INT;
}
}
// The maximum number of possible MIP modes depends on block size & shape
int width = LCU_WIDTH >> depth;
int height = width; // TODO: proper height for non-square blocks.
int tmp_modes;
// MIP_TODO: check for illegal block sizes.
if (width == 4 && height == 4) {
// Mip size_id = 0. Num modes = 32
tmp_modes = 32;
}
else if (width == 4 || height == 4 || (width == 8 && height == 8)) {
// Mip size_id = 1. Num modes = 16
tmp_modes = 16;
}
else {
// Mip size_id = 2. Num modes = 12
tmp_modes = 12;
}
uint8_t num_mip_modes = enable_mip ? tmp_modes : 0;
// Find best intra mode for 2Nx2N.
kvz_pixel *ref_pixels = &lcu->ref.y[lcu_px.x + lcu_px.y * LCU_WIDTH];
@ -1132,24 +1218,37 @@ void kvz_search_cu_intra(encoder_state_t * const state,
}
for(int8_t line = 0; line < lines; ++line) {
// For extra reference lines, only check predicted modes
// For extra reference lines, only check predicted modes & no MIP search.
if (line != 0) {
number_of_modes_to_search = 0;
num_mip_modes = 0;
}
int num_modes_to_check = MIN(number_of_modes[line], number_of_modes_to_search);
kvz_sort_modes(modes[line], costs[line], number_of_modes[line]);
// TODO: if rough search is implemented for MIP, sort mip_modes here.
number_of_modes[line] = search_intra_rdo(state,
x_px, y_px, depth,
ref_pixels, LCU_WIDTH,
candidate_modes,
num_modes_to_check,
modes[line], trafo[line], costs[line], lcu, line);
modes[line], trafo[line], costs[line],
num_mip_modes,
mip_modes, mip_trafo, mip_costs,
lcu, line);
}
}
uint8_t best_line = 0;
double best_line_mode_cost = costs[0][0];
uint8_t best_mip_mode_idx = 0;
uint8_t best_mode_indices[MAX_REF_LINE_IDX];
int8_t tmp_best_mode;
int8_t tmp_best_trafo;
double tmp_best_cost;
bool tmp_mip_flag = false;
bool tmp_mip_transp = false;
for (int line = 0; line < lines; ++line) {
best_mode_indices[line] = select_best_mode_index(modes[line], costs[line], number_of_modes[line]);
if (best_line_mode_cost > costs[line][best_mode_indices[line]]) {
@ -1158,8 +1257,25 @@ void kvz_search_cu_intra(encoder_state_t * const state,
}
}
*mode_out = modes[best_line][best_mode_indices[best_line]];
*trafo_out = trafo[best_line][best_mode_indices[best_line]];
*cost_out = costs[best_line][best_mode_indices[best_line]];
*multi_ref_idx_out = best_line;
tmp_best_mode = modes[best_line][best_mode_indices[best_line]];
tmp_best_trafo = trafo[best_line][best_mode_indices[best_line]];
tmp_best_cost = costs[best_line][best_mode_indices[best_line]];
if (num_mip_modes) {
best_mip_mode_idx = select_best_mode_index(mip_modes, mip_costs, num_mip_modes);
if (tmp_best_cost > mip_costs[best_mip_mode_idx]) {
tmp_best_mode = mip_modes[best_mip_mode_idx];
tmp_best_trafo = mip_trafo[best_mip_mode_idx];
tmp_best_cost = mip_costs[best_mip_mode_idx];
tmp_mip_flag = true;
tmp_mip_transp = (tmp_best_mode >= (num_mip_modes >> 1)) ? 1 : 0;
}
}
*mode_out = tmp_best_mode;
*trafo_out = tmp_best_trafo;
*cost_out = tmp_best_cost;
*mip_flag_out = tmp_mip_flag;
*mip_transposed_out = tmp_mip_transp;
*multi_ref_idx_out = tmp_mip_flag ? 0 : best_line;
}

View file

@ -60,6 +60,8 @@ void kvz_search_cu_intra(encoder_state_t * const state,
int8_t *mode_out,
int8_t *trafo_out,
double *cost_out,
uint8_t *multi_ref_idx_out);
uint8_t *multi_ref_idx_out,
bool *mip_flag,
bool *mip_transp);
#endif // SEARCH_INTRA_H_