shithub: dav1d

Download patch

ref: caca57251642b5fc2a71dee3034cde300ff0c621
parent: ad4d1c4383b7705157807f2c364fca3ee3d713ef
author: Henrik Gramner <[email protected]>
date: Tue Feb 5 10:34:08 EST 2019

mc: Ensure high bitdepth intermediates fit in int16_t

An extreme edge case combining the 8-tap sharp_sharp filter, mx and my both
around 8, and a very specific pixel input pattern can overflow the int16_t
intermediates.

Add code to checkasm to trigger this scenario.

--- a/src/mc_tmpl.c
+++ b/src/mc_tmpl.c
@@ -39,9 +39,14 @@
 
 #if BITDEPTH == 8
 #define get_intermediate_bits(bitdepth_max) 4
+// Output in interval [-5132, 9212], fits in int16_t as is
+#define PREP_BIAS 0
 #else
 // 4 for 10 bits/component, 2 for 12 bits/component
 #define get_intermediate_bits(bitdepth_max) (14 - bitdepth_from_max(bitdepth_max))
+// Output in interval [-20588, 36956] (10-bit), [-20602, 36983] (12-bit)
+// Subtract a bias to ensure the output fits in int16_t
+#define PREP_BIAS 8192
 #endif
 
 static NOINLINE void
@@ -63,7 +68,7 @@
     const int intermediate_bits = get_intermediate_bits(bitdepth_max);
     do {
         for (int x = 0; x < w; x++)
-            tmp[x] = src[x] << intermediate_bits;
+            tmp[x] = (src[x] << intermediate_bits) - PREP_BIAS;
 
         tmp += w;
         src += src_stride;
@@ -237,8 +242,12 @@
 
             mid_ptr = mid + 128 * 3;
             do {
-                for (int x = 0; x < w; x++)
-                    tmp[x] = DAV1D_FILTER_8TAP_RND(mid_ptr, x, fv, 128, 6);
+                for (int x = 0; x < w; x++) {
+                    int t = DAV1D_FILTER_8TAP_RND(mid_ptr, x, fv, 128, 6) -
+                                  PREP_BIAS;
+                    assert(t >= INT16_MIN && t <= INT16_MAX);
+                    tmp[x] = t;
+                }
 
                 mid_ptr += 128;
                 tmp += w;
@@ -247,7 +256,8 @@
             do {
                 for (int x = 0; x < w; x++)
                     tmp[x] = DAV1D_FILTER_8TAP_RND(src, x, fh, 1,
-                                                   6 - intermediate_bits);
+                                                   6 - intermediate_bits) -
+                             PREP_BIAS;
 
                 tmp += w;
                 src += src_stride;
@@ -257,7 +267,8 @@
         do {
             for (int x = 0; x < w; x++)
                 tmp[x] = DAV1D_FILTER_8TAP_RND(src, x, fv, src_stride,
-                                               6 - intermediate_bits);
+                                               6 - intermediate_bits) -
+                         PREP_BIAS;
 
             tmp += w;
             src += src_stride;
@@ -302,7 +313,8 @@
         GET_V_FILTER(my >> 6);
 
         for (x = 0; x < w; x++)
-            tmp[x] = fv ? DAV1D_FILTER_8TAP_RND(mid_ptr, x, fv, 128, 6) : mid_ptr[x];
+            tmp[x] = (fv ? DAV1D_FILTER_8TAP_RND(mid_ptr, x, fv, 128, 6)
+                         : mid_ptr[x]) - PREP_BIAS;
 
         my += dy;
         mid_ptr += (my >> 10) * 128;
@@ -499,7 +511,8 @@
             mid_ptr = mid;
             do {
                 for (int x = 0; x < w; x++)
-                    tmp[x] = FILTER_BILIN_RND(mid_ptr, x, my, 128, 4);
+                    tmp[x] = FILTER_BILIN_RND(mid_ptr, x, my, 128, 4) -
+                             PREP_BIAS;
 
                 mid_ptr += 128;
                 tmp += w;
@@ -508,7 +521,8 @@
             do {
                 for (int x = 0; x < w; x++)
                     tmp[x] = FILTER_BILIN_RND(src, x, mx, 1,
-                                              4 - intermediate_bits);
+                                              4 - intermediate_bits) -
+                             PREP_BIAS;
 
                 tmp += w;
                 src += src_stride;
@@ -518,7 +532,7 @@
         do {
             for (int x = 0; x < w; x++)
                 tmp[x] = FILTER_BILIN_RND(src, x, my, src_stride,
-                                          4 - intermediate_bits);
+                                          4 - intermediate_bits) - PREP_BIAS;
 
             tmp += w;
             src += src_stride;
@@ -557,7 +571,7 @@
         int x;
 
         for (x = 0; x < w; x++)
-            tmp[x] = FILTER_BILIN_RND(mid_ptr, x, my >> 6, 128, 4);
+            tmp[x] = FILTER_BILIN_RND(mid_ptr, x, my >> 6, 128, 4) - PREP_BIAS;
 
         my += dy;
         mid_ptr += (my >> 10) * 128;
@@ -571,7 +585,8 @@
                   HIGHBD_DECL_SUFFIX)
 {
     const int intermediate_bits = get_intermediate_bits(bitdepth_max);
-    const int sh = intermediate_bits + 1, rnd = 1 << intermediate_bits;
+    const int sh = intermediate_bits + 1;
+    const int rnd = (1 << intermediate_bits) + PREP_BIAS * 2;
     do {
         for (int x = 0; x < w; x++)
             dst[x] = iclip_pixel((tmp1[x] + tmp2[x] + rnd) >> sh);
@@ -587,7 +602,8 @@
                     const int weight HIGHBD_DECL_SUFFIX)
 {
     const int intermediate_bits = get_intermediate_bits(bitdepth_max);
-    const int sh = intermediate_bits + 4, rnd = 8 << intermediate_bits;
+    const int sh = intermediate_bits + 4;
+    const int rnd = (8 << intermediate_bits) + PREP_BIAS * 16;
     do {
         for (int x = 0; x < w; x++)
             dst[x] = iclip_pixel((tmp1[x] * weight +
@@ -604,7 +620,8 @@
                    const uint8_t *mask HIGHBD_DECL_SUFFIX)
 {
     const int intermediate_bits = get_intermediate_bits(bitdepth_max);
-    const int sh = intermediate_bits + 6, rnd = 32 << intermediate_bits;
+    const int sh = intermediate_bits + 6;
+    const int rnd = (32 << intermediate_bits) + PREP_BIAS * 64;
     do {
         for (int x = 0; x < w; x++)
             dst[x] = iclip_pixel((tmp1[x] * mask[x] +
@@ -668,7 +685,8 @@
     // and then load this intermediate to calculate final value for odd rows
     const int intermediate_bits = get_intermediate_bits(bitdepth_max);
     const int bitdepth = bitdepth_from_max(bitdepth_max);
-    const int sh = intermediate_bits + 6, rnd = 32 << intermediate_bits;
+    const int sh = intermediate_bits + 6;
+    const int rnd = (32 << intermediate_bits) + PREP_BIAS * 64;
     const int mask_sh = bitdepth + intermediate_bits - 4;
     const int mask_rnd = 1 << (mask_sh - 5);
     do {
@@ -797,7 +815,7 @@
             const int8_t *const filter =
                 dav1d_mc_warp_filter[64 + ((tmy + 512) >> 10)];
 
-            tmp[x] = FILTER_WARP_RND(mid_ptr, x, filter, 8, 7);
+            tmp[x] = FILTER_WARP_RND(mid_ptr, x, filter, 8, 7) - PREP_BIAS;
         }
         mid_ptr += 8;
         tmp += tmp_stride;
--- a/tests/checkasm/mc.c
+++ b/tests/checkasm/mc.c
@@ -84,6 +84,17 @@
     report("mc");
 }
 
+/* Generate worst-case input in the top-left corner, randomize the rest */
+static void generate_mct_input(pixel *const buf, const int bitdepth_max) {
+    static const int8_t pattern[8] = { -1,  0, -1,  0,  0, -1,  0, -1 };
+    const int sign = -(rnd() & 1);
+
+    for (int y = 0; y < 135; y++)
+        for (int x = 0; x < 135; x++)
+            buf[135*y+x] = ((x | y) < 8 ? (pattern[x] ^ pattern[y] ^ sign)
+                                        : rnd()) & bitdepth_max;
+}
+
 static void check_mct(Dav1dMCDSPContext *const c) {
     ALIGN_STK_32(pixel, src_buf, 135 * 135,);
     ALIGN_STK_32(int16_t, c_tmp,   128 * 128,);
@@ -107,10 +118,8 @@
 #else
                         const int bitdepth_max = 0xff;
 #endif
+                        generate_mct_input(src_buf, bitdepth_max);
 
-                        for (int i = 0; i < 135 * 135; i++)
-                            src_buf[i] = rnd() & bitdepth_max;
-
                         call_ref(c_tmp, src, w, w, h, mx, my HIGHBD_TAIL_SUFFIX);
                         call_new(a_tmp, src, w, w, h, mx, my HIGHBD_TAIL_SUFFIX);
                         if (memcmp(c_tmp, a_tmp, w * h * sizeof(*c_tmp)))
@@ -127,12 +136,10 @@
                      int16_t (*const tmp)[128 * 128], const int bitdepth_max)
 {
     for (int i = 0; i < 2; i++) {
-        for (int j = 0; j < 135 * 135; j++)
-            buf[j] = rnd() & bitdepth_max;
-        c->mct[rnd() % N_2D_FILTERS](tmp[i], buf + 135 * 3 + 3,
-                                      128 * sizeof(pixel), 128, 128,
-                                      rnd() & 15, rnd() & 15
-                                      HIGHBD_TAIL_SUFFIX);
+        generate_mct_input(buf, bitdepth_max);
+        c->mct[FILTER_2D_8TAP_SHARP](tmp[i], buf + 135 * 3 + 3,
+                                      135 * sizeof(pixel), 128, 128,
+                                      8, 8 HIGHBD_TAIL_SUFFIX);
     }
 }