diff --git a/src/search.c b/src/search.c index 0fe9a505..88e1ca9d 100644 --- a/src/search.c +++ b/src/search.c @@ -1074,19 +1074,21 @@ static void sort_modes(int8_t *modes, double *costs, int length) } -static unsigned get_cost(pixel *pred, pixel *orig_block, cost_pixel_nxn_func *satd_func, cost_pixel_nxn_func *sad_func, int width) +static double get_cost(encoder_state * const encoder_state, pixel *pred, pixel *orig_block, cost_pixel_nxn_func *satd_func, cost_pixel_nxn_func *sad_func, int width) { - unsigned cost = satd_func(pred, orig_block); + double satd_cost = satd_func(pred, orig_block); if (MN != 0 && width == 4) { // If the mode looks better with SAD than SATD it might be a good // candidate for transform skip. How much better SAD has to be is // controlled by MN. - unsigned sad_cost = MN * sad_func(pred, orig_block); - if (sad_cost < cost) { - cost = sad_cost; + const cabac_ctx *ctx = &encoder_state->cabac.ctx.transform_skip_model_luma; + double trskip_cost = encoder_state->global->cur_lambda_cost_sqrt * (CTX_ENTROPY_FBITS(ctx, 1) - CTX_ENTROPY_FBITS(ctx, 0)); + double sad_cost = MN * sad_func(pred, orig_block) + trskip_cost; + if (sad_cost < satd_cost) { + return sad_cost; } } - return cost; + return satd_cost; } @@ -1148,7 +1150,7 @@ static int8_t search_intra_rough(encoder_state * const encoder_state, // the recursive search. for (int mode = 2; mode <= 34; mode += offset) { intra_get_pred(encoder_state->encoder_control, ref, recstride, pred, width, mode, 0); - costs[modes_selected] = get_cost(pred, orig_block, satd_func, sad_func, width); + costs[modes_selected] = get_cost(encoder_state, pred, orig_block, satd_func, sad_func, width); modes[modes_selected] = mode; min_cost = MIN(min_cost, costs[modes_selected]); @@ -1168,7 +1170,7 @@ static int8_t search_intra_rough(encoder_state * const encoder_state, int8_t mode = modes[0] - offset; if (mode >= 2) { intra_get_pred(encoder_state->encoder_control, ref, recstride, pred, width, mode, 0); - costs[modes_selected] = get_cost(pred, orig_block, satd_func, sad_func, width); + costs[modes_selected] = get_cost(encoder_state, pred, orig_block, satd_func, sad_func, width); modes[modes_selected] = mode; ++modes_selected; } @@ -1176,7 +1178,7 @@ static int8_t search_intra_rough(encoder_state * const encoder_state, mode = modes[0] + offset; if (mode <= 34) { intra_get_pred(encoder_state->encoder_control, ref, recstride, pred, width, mode, 0); - costs[modes_selected] = get_cost(pred, orig_block, satd_func, sad_func, width); + costs[modes_selected] = get_cost(encoder_state, pred, orig_block, satd_func, sad_func, width); modes[modes_selected] = mode; ++modes_selected; } @@ -1199,7 +1201,7 @@ static int8_t search_intra_rough(encoder_state * const encoder_state, if (!has_mode) { intra_get_pred(encoder_state->encoder_control, ref, recstride, pred, width, mode, 0); - costs[modes_selected] = get_cost(pred, orig_block, satd_func, sad_func, width); + costs[modes_selected] = get_cost(encoder_state, pred, orig_block, satd_func, sad_func, width); modes[modes_selected] = mode; ++modes_selected; }