Fix the distortion and bit values reported for each CU

This commit is contained in:
Joose Sainio 2024-01-31 13:03:33 +02:00
parent 6dfa89d3a6
commit 24527c5be8

View file

@ -748,7 +748,8 @@ static double cu_rd_cost_tr_split_accurate(
uint8_t isp_cbf,
const cu_loc_t* const cu_loc,
const cu_loc_t* const chroma_loc,
bool has_chroma) {
bool has_chroma,
double *bits) {
const int width = cu_loc->width;
const int height = cu_loc->height; // TODO: height for non-square blocks
@ -793,8 +794,9 @@ static double cu_rd_cost_tr_split_accurate(
uvg_get_split_locs(chroma_loc, split, split_chroma_cu_loc, NULL);
}
for (int i = 0; i < split_count; ++i) {
sum += cu_rd_cost_tr_split_accurate(state, pred_cu, lcu, tree_type, isp_cbf, &split_cu_loc[i], chroma_loc ? &split_chroma_cu_loc[i] : NULL, has_chroma);
sum += cu_rd_cost_tr_split_accurate(state, pred_cu, lcu, tree_type, isp_cbf, &split_cu_loc[i], chroma_loc ? &split_chroma_cu_loc[i] : NULL, has_chroma, bits);
}
if (bits) *bits += luma_bits;
return sum + luma_bits * state->lambda;
}
@ -998,8 +1000,9 @@ static double cu_rd_cost_tr_split_accurate(
tr_cu->violates_mts_coeff_constraint = false;
}
double bits = luma_bits + coeff_bits;
return luma_ssd * UVG_LUMA_MULT + chroma_ssd * UVG_CHROMA_MULT + (bits + chroma_bits) * state->lambda;
luma_bits += coeff_bits;
if (bits) *bits += luma_bits + chroma_bits;
return luma_ssd * UVG_LUMA_MULT + chroma_ssd * UVG_CHROMA_MULT + (luma_bits + chroma_bits) * state->lambda;
}
@ -1655,69 +1658,7 @@ static double search_cu(
UVG_GET_TIME(&end_time);
}
if (cur_cu->type != CU_NOTSET) {
uint8_t type = 0;
uint8_t buffer[8192];
UVG_CLOCK_T time;
UVG_GET_TIME(&time);
#ifdef _MSC_VER
uint64_t time_high = time.dwHighDateTime & 0x000fffff;
time_high <<= 32;
uint64_t time_low = time.dwLowDateTime;
uint64_t timestamp = time_high | time_low;
timestamp *= 100;
#else
uint64_t timestamp = time.tv_sec * 1000000000 + time.tv_nsec;
#endif
uint64_t bytes = 0;
memcpy(buffer, &type, 1); bytes++;
memcpy(buffer + bytes, &timestamp, 8); bytes+=8;
memcpy(buffer + bytes, &state->frame->num, 1); bytes++;
memcpy(buffer + bytes, &cu_loc->x, 2); bytes+=2;
memcpy(buffer + bytes, &cu_loc->y, 2); bytes+=2;
memcpy(buffer + bytes, &cu_loc->width, 1); bytes++;
memcpy(buffer + bytes, &cu_loc->height, 1); bytes++;
memcpy(buffer + bytes, &cur_cu->split_tree, 4); bytes+=4;
memcpy(buffer + bytes, &cur_cu->qp, 1); bytes++;
memcpy(buffer + bytes, &cur_cu->intra.mode, 1); bytes++;
memcpy(buffer + bytes, &cur_cu->intra.mip_flag, 1); bytes++;
memcpy(buffer + bytes, &cur_cu->intra.mip_is_transposed, 1); bytes++;
memcpy(buffer + bytes, &cur_cu->intra.multi_ref_idx, 1); bytes++;
memcpy(buffer + bytes, &cur_cu->intra.isp_mode, 1); bytes++;
memcpy(buffer + bytes, &cur_cu->lfnst_idx, 1); bytes++;
memcpy(buffer + bytes, &cur_cu->tr_idx, 1); bytes++;
float float_cost = (float)intra_search.cost;
float float_bits = (float)intra_search.bits;
float float_dist = (float)intra_search.distortion;
memcpy(buffer + bytes, &float_cost, 4); bytes+=4;
memcpy(buffer + bytes, &float_bits, 4); bytes+=4;
memcpy(buffer + bytes, &float_dist, 4); bytes+=4;
uvg_pixels_blit(&lcu->rec.y[x_local + y_local * LCU_WIDTH], buffer + bytes, cu_width, cu_height, LCU_WIDTH, cu_width);
bytes += cu_loc->width*cu_loc->height;
uvg_pixels_blit(&lcu->rec.u[x_local/2 + y_local/2 * LCU_WIDTH_C], buffer + bytes, cu_width / 2, cu_height / 2, LCU_WIDTH_C, cu_width / 2);
bytes += cu_loc->width*cu_loc->height / 4;
uvg_pixels_blit(&lcu->rec.v[x_local/2 + y_local/2 * LCU_WIDTH_C], buffer + bytes, cu_width / 2, cu_height / 2, LCU_WIDTH_C, cu_width / 2);
bytes += cu_loc->width*cu_loc->height / 4;
zmq_send(state->send_socket, buffer, bytes, 0);
if (state->frame->cfg->speed > 1) {
#if defined(__GNUC__) && !defined(__MINGW32__)
struct timespec sleep_time;
sleep_time.tv_sec = 0;
sleep_time.tv_nsec = UVG_CLOCK_T_DIFF(start_time, end_time) * 1e9 * (state->frame->cfg->speed - 1);
nanosleep(&sleep_time, NULL);
#else
usleep(
UVG_CLOCK_T_DIFF(start_time, end_time) * 1e6 *
(state->frame->cfg->speed - 1));
#endif
}
}
// The cabac functions assume chroma locations, whereas the search uses luma
// locations for the chroma tree. Therefore the chroma coordinates need to be
// shifted here before they are passed to the bit cost calculation functions.
@ -1727,6 +1668,8 @@ static double search_cu(
separate_tree_chroma_loc.width >>= 1;
separate_tree_chroma_loc.height >>= 1;
double block_bits = 0;
if (cur_cu->type == CU_INTRA || cur_cu->type == CU_INTER || cur_cu->type == CU_IBC) {
double bits = 0;
cabac_data_t* cabac = &state->search_cabac;
@ -1745,7 +1688,69 @@ static double search_cu(
cost = bits * state->lambda;
cost += cu_rd_cost_tr_split_accurate(state, cur_cu, lcu, tree_type, intra_search.best_isp_cbfs, cu_loc, chroma_loc, has_chroma);
cost += cu_rd_cost_tr_split_accurate(state, cur_cu, lcu, tree_type, intra_search.best_isp_cbfs, cu_loc, chroma_loc, has_chroma, &bits);
block_bits = bits;
{
uint8_t type = 0;
uint8_t buffer[8192];
UVG_CLOCK_T time;
UVG_GET_TIME(&time);
#ifdef _MSC_VER
uint64_t time_high = time.dwHighDateTime & 0x000fffff;
time_high <<= 32;
uint64_t time_low = time.dwLowDateTime;
uint64_t timestamp = time_high | time_low;
timestamp *= 100;
#else
uint64_t timestamp = time.tv_sec * 1000000000 + time.tv_nsec;
#endif
uint64_t bytes = 0;
memcpy(buffer, &type, 1); bytes++;
memcpy(buffer + bytes, &timestamp, 8); bytes+=8;
memcpy(buffer + bytes, &state->frame->num, 1); bytes++;
memcpy(buffer + bytes, &cu_loc->x, 2); bytes+=2;
memcpy(buffer + bytes, &cu_loc->y, 2); bytes+=2;
memcpy(buffer + bytes, &cu_loc->width, 1); bytes++;
memcpy(buffer + bytes, &cu_loc->height, 1); bytes++;
memcpy(buffer + bytes, &cur_cu->split_tree, 4); bytes+=4;
memcpy(buffer + bytes, &cur_cu->qp, 1); bytes++;
memcpy(buffer + bytes, &cur_cu->intra.mode, 1); bytes++;
memcpy(buffer + bytes, &cur_cu->intra.mip_flag, 1); bytes++;
memcpy(buffer + bytes, &cur_cu->intra.mip_is_transposed, 1); bytes++;
memcpy(buffer + bytes, &cur_cu->intra.multi_ref_idx, 1); bytes++;
memcpy(buffer + bytes, &cur_cu->intra.isp_mode, 1); bytes++;
memcpy(buffer + bytes, &cur_cu->lfnst_idx, 1); bytes++;
memcpy(buffer + bytes, &cur_cu->tr_idx, 1); bytes++;
float float_cost = (float)cost;
float float_bits = (float)bits;
float float_dist = (float)(cost - bits * state->lambda);
memcpy(buffer + bytes, &float_cost, 4); bytes+=4;
memcpy(buffer + bytes, &float_bits, 4); bytes+=4;
memcpy(buffer + bytes, &float_dist, 4); bytes+=4;
uvg_pixels_blit(&lcu->rec.y[x_local + y_local * LCU_WIDTH], buffer + bytes, cu_width, cu_height, LCU_WIDTH, cu_width);
bytes += cu_loc->width*cu_loc->height;
uvg_pixels_blit(&lcu->rec.u[x_local/2 + y_local/2 * LCU_WIDTH_C], buffer + bytes, cu_width / 2, cu_height / 2, LCU_WIDTH_C, cu_width / 2);
bytes += cu_loc->width*cu_loc->height / 4;
uvg_pixels_blit(&lcu->rec.v[x_local/2 + y_local/2 * LCU_WIDTH_C], buffer + bytes, cu_width / 2, cu_height / 2, LCU_WIDTH_C, cu_width / 2);
bytes += cu_loc->width*cu_loc->height / 4;
zmq_send(state->send_socket, buffer, bytes, 0);
if (state->frame->cfg->speed > 1) {
#if defined(__GNUC__) && !defined(__MINGW32__)
struct timespec sleep_time;
sleep_time.tv_sec = 0;
sleep_time.tv_nsec = UVG_CLOCK_T_DIFF(start_time, end_time) * 1e9 * (state->frame->cfg->speed - 1);
nanosleep(&sleep_time, NULL);
#else
usleep(
UVG_CLOCK_T_DIFF(start_time, end_time) * 1e6 *
(state->frame->cfg->speed - 1));
#endif
}
}
//fprintf(stderr, "%4d %4d %2d %2d %d %d %f\n", x, y, cu_width, cu_height, has_chroma, cur_cu->split_tree, cost);
//if (ctrl->cfg.zero_coeff_rdo && inter_zero_coeff_cost <= cost) {
@ -2035,7 +2040,7 @@ static double search_cu(
double mode_bits = calc_mode_bits(state, lcu, cur_cu, cu_loc) + bits;
cost += mode_bits * state->lambda;
cost += cu_rd_cost_tr_split_accurate(state, cur_cu, lcu, tree_type, 0, cu_loc, chroma_loc, has_chroma);
cost += cu_rd_cost_tr_split_accurate(state, cur_cu, lcu, tree_type, 0, cu_loc, chroma_loc, has_chroma, &bits);
mark_deblocking(cu_loc, chroma_loc, lcu, tree_type, has_chroma, is_separate_tree, x_local, y_local);
@ -2076,8 +2081,8 @@ static double search_cu(
memcpy(buffer + bytes, &cur_cu->lfnst_idx, 1); bytes++;
memcpy(buffer + bytes, &cur_cu->tr_idx, 1); bytes++;
float float_cost = (float)cost;
float float_bits = (float)((cost - intra_search.distortion) / state->lambda);
float float_dist = (float)intra_search.distortion;
float float_bits = (float)block_bits;
float float_dist = (float)(cost - block_bits * state->lambda);
memcpy(buffer + bytes, &float_cost, 4); bytes+=4;
memcpy(buffer + bytes, &float_bits, 4); bytes+=4;
memcpy(buffer + bytes, &float_dist, 4); bytes+=4;