ref: caca57251642b5fc2a71dee3034cde300ff0c621
parent: ad4d1c4383b7705157807f2c364fca3ee3d713ef
author: Henrik Gramner <[email protected]>
date: Tue Feb 5 10:34:08 EST 2019
mc: Ensure high bitdepth intermediates fit in int16_t. An extreme edge case with the combination of the 8-tap sharp_sharp filter, mx and my both around 8, and a very specific pixel input pattern can cause overflows. Add code to checkasm to trigger this scenario.
--- a/src/mc_tmpl.c
+++ b/src/mc_tmpl.c
@@ -39,9 +39,14 @@
#if BITDEPTH == 8
#define get_intermediate_bits(bitdepth_max) 4
+// Output in interval [-5132, 9212], fits in int16_t as is
+#define PREP_BIAS 0
#else
// 4 for 10 bits/component, 2 for 12 bits/component
#define get_intermediate_bits(bitdepth_max) (14 - bitdepth_from_max(bitdepth_max))
+// Output in interval [-20588, 36956] (10-bit), [-20602, 36983] (12-bit)
+// Subtract a bias to ensure the output fits in int16_t
+#define PREP_BIAS 8192
#endif
static NOINLINE void
@@ -63,7 +68,7 @@
const int intermediate_bits = get_intermediate_bits(bitdepth_max);
do {
for (int x = 0; x < w; x++)
- tmp[x] = src[x] << intermediate_bits;
+ tmp[x] = (src[x] << intermediate_bits) - PREP_BIAS;
tmp += w;
src += src_stride;
@@ -237,8 +242,12 @@
mid_ptr = mid + 128 * 3;
do {
- for (int x = 0; x < w; x++)
- tmp[x] = DAV1D_FILTER_8TAP_RND(mid_ptr, x, fv, 128, 6);
+ for (int x = 0; x < w; x++) {
+ int t = DAV1D_FILTER_8TAP_RND(mid_ptr, x, fv, 128, 6) -
+ PREP_BIAS;
+ assert(t >= INT16_MIN && t <= INT16_MAX);
+ tmp[x] = t;
+ }
mid_ptr += 128;
tmp += w;
@@ -247,7 +256,8 @@
do {
for (int x = 0; x < w; x++)
tmp[x] = DAV1D_FILTER_8TAP_RND(src, x, fh, 1,
- 6 - intermediate_bits);
+ 6 - intermediate_bits) -
+ PREP_BIAS;
tmp += w;
src += src_stride;
@@ -257,7 +267,8 @@
do {
for (int x = 0; x < w; x++)
tmp[x] = DAV1D_FILTER_8TAP_RND(src, x, fv, src_stride,
- 6 - intermediate_bits);
+ 6 - intermediate_bits) -
+ PREP_BIAS;
tmp += w;
src += src_stride;
@@ -302,7 +313,8 @@
GET_V_FILTER(my >> 6);
for (x = 0; x < w; x++)
- tmp[x] = fv ? DAV1D_FILTER_8TAP_RND(mid_ptr, x, fv, 128, 6) : mid_ptr[x];
+ tmp[x] = (fv ? DAV1D_FILTER_8TAP_RND(mid_ptr, x, fv, 128, 6)
+ : mid_ptr[x]) - PREP_BIAS;
my += dy;
mid_ptr += (my >> 10) * 128;
@@ -499,7 +511,8 @@
mid_ptr = mid;
do {
for (int x = 0; x < w; x++)
- tmp[x] = FILTER_BILIN_RND(mid_ptr, x, my, 128, 4);
+ tmp[x] = FILTER_BILIN_RND(mid_ptr, x, my, 128, 4) -
+ PREP_BIAS;
mid_ptr += 128;
tmp += w;
@@ -508,7 +521,8 @@
do {
for (int x = 0; x < w; x++)
tmp[x] = FILTER_BILIN_RND(src, x, mx, 1,
- 4 - intermediate_bits);
+ 4 - intermediate_bits) -
+ PREP_BIAS;
tmp += w;
src += src_stride;
@@ -518,7 +532,7 @@
do {
for (int x = 0; x < w; x++)
tmp[x] = FILTER_BILIN_RND(src, x, my, src_stride,
- 4 - intermediate_bits);
+ 4 - intermediate_bits) - PREP_BIAS;
tmp += w;
src += src_stride;
@@ -557,7 +571,7 @@
int x;
for (x = 0; x < w; x++)
- tmp[x] = FILTER_BILIN_RND(mid_ptr, x, my >> 6, 128, 4);
+ tmp[x] = FILTER_BILIN_RND(mid_ptr, x, my >> 6, 128, 4) - PREP_BIAS;
my += dy;
mid_ptr += (my >> 10) * 128;
@@ -571,7 +585,8 @@
HIGHBD_DECL_SUFFIX)
{
const int intermediate_bits = get_intermediate_bits(bitdepth_max);
- const int sh = intermediate_bits + 1, rnd = 1 << intermediate_bits;
+ const int sh = intermediate_bits + 1;
+ const int rnd = (1 << intermediate_bits) + PREP_BIAS * 2;
do {
for (int x = 0; x < w; x++)
dst[x] = iclip_pixel((tmp1[x] + tmp2[x] + rnd) >> sh);
@@ -587,7 +602,8 @@
const int weight HIGHBD_DECL_SUFFIX)
{
const int intermediate_bits = get_intermediate_bits(bitdepth_max);
- const int sh = intermediate_bits + 4, rnd = 8 << intermediate_bits;
+ const int sh = intermediate_bits + 4;
+ const int rnd = (8 << intermediate_bits) + PREP_BIAS * 16;
do {
for (int x = 0; x < w; x++)
dst[x] = iclip_pixel((tmp1[x] * weight +
@@ -604,7 +620,8 @@
const uint8_t *mask HIGHBD_DECL_SUFFIX)
{
const int intermediate_bits = get_intermediate_bits(bitdepth_max);
- const int sh = intermediate_bits + 6, rnd = 32 << intermediate_bits;
+ const int sh = intermediate_bits + 6;
+ const int rnd = (32 << intermediate_bits) + PREP_BIAS * 64;
do {
for (int x = 0; x < w; x++)
dst[x] = iclip_pixel((tmp1[x] * mask[x] +
@@ -668,7 +685,8 @@
// and then load this intermediate to calculate final value for odd rows
const int intermediate_bits = get_intermediate_bits(bitdepth_max);
const int bitdepth = bitdepth_from_max(bitdepth_max);
- const int sh = intermediate_bits + 6, rnd = 32 << intermediate_bits;
+ const int sh = intermediate_bits + 6;
+ const int rnd = (32 << intermediate_bits) + PREP_BIAS * 64;
const int mask_sh = bitdepth + intermediate_bits - 4;
const int mask_rnd = 1 << (mask_sh - 5);
do {
@@ -797,7 +815,7 @@
const int8_t *const filter =
dav1d_mc_warp_filter[64 + ((tmy + 512) >> 10)];
- tmp[x] = FILTER_WARP_RND(mid_ptr, x, filter, 8, 7);
+ tmp[x] = FILTER_WARP_RND(mid_ptr, x, filter, 8, 7) - PREP_BIAS;
}
mid_ptr += 8;
tmp += tmp_stride;
--- a/tests/checkasm/mc.c
+++ b/tests/checkasm/mc.c
@@ -84,6 +84,17 @@
report("mc");
}
+/* Fill the top-left 8x8 corner with a worst-case pattern (8-tap sharp, mx/my ~8), randomize the rest */
+static void generate_mct_input(pixel *const buf, const int bitdepth_max) {
+ static const int8_t pattern[8] = { -1, 0, -1, 0, 0, -1, 0, -1 }; // -1 becomes bitdepth_max, 0 stays 0 after masking
+ const int sign = -(rnd() & 1); // 0 or -1; XOR with -1 inverts the whole corner pattern
+
+ for (int y = 0; y < 135; y++)
+ for (int x = 0; x < 135; x++)
+ buf[135*y+x] = ((x | y) < 8 ? (pattern[x] ^ pattern[y] ^ sign) // (x | y) < 8 <=> x < 8 && y < 8
+ : rnd()) & bitdepth_max;
+}
+
static void check_mct(Dav1dMCDSPContext *const c) {
ALIGN_STK_32(pixel, src_buf, 135 * 135,);
ALIGN_STK_32(int16_t, c_tmp, 128 * 128,);
@@ -107,10 +118,8 @@
#else
const int bitdepth_max = 0xff;
#endif
+ generate_mct_input(src_buf, bitdepth_max);
- for (int i = 0; i < 135 * 135; i++)
- src_buf[i] = rnd() & bitdepth_max;
-
call_ref(c_tmp, src, w, w, h, mx, my HIGHBD_TAIL_SUFFIX);
call_new(a_tmp, src, w, w, h, mx, my HIGHBD_TAIL_SUFFIX);
if (memcmp(c_tmp, a_tmp, w * h * sizeof(*c_tmp)))
@@ -127,12 +136,10 @@
int16_t (*const tmp)[128 * 128], const int bitdepth_max)
{
for (int i = 0; i < 2; i++) {
- for (int j = 0; j < 135 * 135; j++)
- buf[j] = rnd() & bitdepth_max;
- c->mct[rnd() % N_2D_FILTERS](tmp[i], buf + 135 * 3 + 3,
- 128 * sizeof(pixel), 128, 128,
- rnd() & 15, rnd() & 15
- HIGHBD_TAIL_SUFFIX);
+ generate_mct_input(buf, bitdepth_max);
+ c->mct[FILTER_2D_8TAP_SHARP](tmp[i], buf + 135 * 3 + 3,
+ 135 * sizeof(pixel), 128, 128,
+ 8, 8 HIGHBD_TAIL_SUFFIX);
}
}