shithub: dav1d

Download patch

ref: c59f19405362091741f441ff1a98810955a56a3f
parent: 2532642bbbdfcc77140846e1403a6b393eaba974
author: Rupert Swarbrick <[email protected]>
date: Wed Oct 17 13:49:35 EDT 2018

Correctly flush at the end of OBUs

This fixes failures when an OBU has more than a byte's worth of
trailing zeros.

As part of this work, it also rejigs the dav1d_flush_get_bits function
slightly. This worked before, but it wasn't very obvious why (it
worked because bits_left was never more than 7). This patch renames it
to dav1d_bytealign_get_bits, which makes it clearer what it does and
adds a comment explaining why it works properly.

The new dav1d_bytealign_get_bits is also now void (rather than
returning the next byte to read). The patch defines
dav1d_get_bits_pos, which returns the current bit position. This feels
a little easier to reason about.

We also add a new check to make sure that we haven't fallen off the
end of the OBU. This can happen when a byte buffer contains more than
one OBU: the GetBits might not have got to EOF, but we might now be
half-way through the next OBU.

--- a/src/getbits.c
+++ b/src/getbits.c
@@ -126,8 +126,16 @@
     return (int) get_bits_subexp_u(c, ref + (1 << n), 2 << n) - (1 << n);
 }
 
-const uint8_t *dav1d_flush_get_bits(GetBits *c) {
+void dav1d_bytealign_get_bits(GetBits *c) {
+    // bits_left is never more than 7: it only grows in refill(), which
+    // dav1d_get_bits calls, and refill() never reads more than 7 bits
+    // beyond what the current read needs.
+    //
+    // If this weren't true, we would have to discard only the
+    // (bits_left % 8) bits of the partial byte: subtract that from
+    // bits_left and shift those bits out of state.
+    assert(c->bits_left <= 7);
+
     c->bits_left = 0;
     c->state = 0;
-    return c->ptr;
 }
--- a/src/getbits.h
+++ b/src/getbits.h
@@ -46,6 +46,13 @@
 unsigned dav1d_get_uniform(GetBits *c, unsigned max);
 unsigned dav1d_get_vlc(GetBits *c);
 int dav1d_get_bits_subexp(GetBits *c, int ref, unsigned n);
-const uint8_t *dav1d_flush_get_bits(GetBits *c);
+
+// Drop the unread bits of the current partial byte so the next read is byte-aligned.
+void dav1d_bytealign_get_bits(GetBits *c);
+
+// Return the number of bits consumed so far, counted from the start of the buffer.
+static inline unsigned dav1d_get_bits_pos(const GetBits *c) {
+    return (c->ptr - c->ptr_start) * 8 - c->bits_left;
+}
 
 #endif /* __DAV1D_SRC_GETBITS_H__ */
--- a/src/obu.c
+++ b/src/obu.c
@@ -46,15 +46,17 @@
 static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
                          Av1SequenceHeader *const hdr)
 {
-    const uint8_t *const init_ptr = gb->ptr;
-
 #define DEBUG_SEQ_HDR 0
 
+#if DEBUG_SEQ_HDR
+    const unsigned init_bit_pos = dav1d_get_bits_pos(gb);
+#endif
+
     hdr->profile = dav1d_get_bits(gb, 3);
     if (hdr->profile > 2) goto error;
 #if DEBUG_SEQ_HDR
     printf("SEQHDR: post-profile: off=%ld\n",
-           (gb->ptr - init_ptr) * 8 - gb->bits_left);
+           (long) (dav1d_get_bits_pos(gb) - init_bit_pos));
 #endif
 
     hdr->still_picture = dav1d_get_bits(gb, 1);
@@ -62,7 +64,7 @@
     if (hdr->reduced_still_picture_header && !hdr->still_picture) goto error;
 #if DEBUG_SEQ_HDR
     printf("SEQHDR: post-stillpicture_flags: off=%ld\n",
-           (gb->ptr - init_ptr) * 8 - gb->bits_left);
+           (long) (dav1d_get_bits_pos(gb) - init_bit_pos));
 #endif
 
     if (hdr->reduced_still_picture_header) {
@@ -97,7 +99,7 @@
         }
 #if DEBUG_SEQ_HDR
         printf("SEQHDR: post-timinginfo: off=%ld\n",
-               (gb->ptr - init_ptr) * 8 - gb->bits_left);
+               (long) (dav1d_get_bits_pos(gb) - init_bit_pos));
 #endif
 
         hdr->display_model_info_present = dav1d_get_bits(gb, 1);
@@ -126,7 +128,7 @@
         }
 #if DEBUG_SEQ_HDR
         printf("SEQHDR: post-operating-points: off=%ld\n",
-               (gb->ptr - init_ptr) * 8 - gb->bits_left);
+               (long) (dav1d_get_bits_pos(gb) - init_bit_pos));
 #endif
     }
 
@@ -136,7 +138,7 @@
     hdr->max_height = dav1d_get_bits(gb, hdr->height_n_bits) + 1;
 #if DEBUG_SEQ_HDR
     printf("SEQHDR: post-size: off=%ld\n",
-           (gb->ptr - init_ptr) * 8 - gb->bits_left);
+           (long) (dav1d_get_bits_pos(gb) - init_bit_pos));
 #endif
     hdr->frame_id_numbers_present =
         hdr->reduced_still_picture_header ? 0 : dav1d_get_bits(gb, 1);
@@ -146,7 +148,7 @@
     }
 #if DEBUG_SEQ_HDR
     printf("SEQHDR: post-frame-id-numbers-present: off=%ld\n",
-           (gb->ptr - init_ptr) * 8 - gb->bits_left);
+           (long) (dav1d_get_bits_pos(gb) - init_bit_pos));
 #endif
 
     hdr->sb128 = dav1d_get_bits(gb, 1);
@@ -180,7 +182,7 @@
         hdr->screen_content_tools = dav1d_get_bits(gb, 1) ? ADAPTIVE : dav1d_get_bits(gb, 1);
     #if DEBUG_SEQ_HDR
         printf("SEQHDR: post-screentools: off=%ld\n",
-               (gb->ptr - init_ptr) * 8 - gb->bits_left);
+               (long) (dav1d_get_bits_pos(gb) - init_bit_pos));
     #endif
         hdr->force_integer_mv = hdr->screen_content_tools ?
                                 dav1d_get_bits(gb, 1) ? ADAPTIVE : dav1d_get_bits(gb, 1) : 2;
@@ -192,7 +194,7 @@
     hdr->restoration = dav1d_get_bits(gb, 1);
 #if DEBUG_SEQ_HDR
     printf("SEQHDR: post-featurebits: off=%ld\n",
-           (gb->ptr - init_ptr) * 8 - gb->bits_left);
+           (long) (dav1d_get_bits_pos(gb) - init_bit_pos));
 #endif
 
     const int hbd = dav1d_get_bits(gb, 1);
@@ -243,19 +245,23 @@
     }
 #if DEBUG_SEQ_HDR
     printf("SEQHDR: post-colorinfo: off=%ld\n",
-           (gb->ptr - init_ptr) * 8 - gb->bits_left);
+           (long) (dav1d_get_bits_pos(gb) - init_bit_pos));
 #endif
 
     hdr->film_grain_present = dav1d_get_bits(gb, 1);
 #if DEBUG_SEQ_HDR
     printf("SEQHDR: post-filmgrain: off=%ld\n",
-           (gb->ptr - init_ptr) * 8 - gb->bits_left);
+           (long) (dav1d_get_bits_pos(gb) - init_bit_pos));
 #endif
 
     dav1d_get_bits(gb, 1); // dummy bit
 
-    return dav1d_flush_get_bits(gb) - init_ptr;
+    // We needn't bother flushing the OBU here: the caller checks that we
+    // didn't overrun and then discards gb, so there's no point in
+    // byte-aligning its position.
 
+    return 0;
+
 error:
     fprintf(stderr, "Error parsing sequence header\n");
     return -EINVAL;
@@ -313,16 +319,16 @@
     .ref_delta = { 1, 0, 0, 0, -1, 0, -1, -1 },
 };
 
-static int parse_frame_hdr(Dav1dContext *const c, GetBits *const gb,
-                           const int have_trailing_bit)
-{
+static int parse_frame_hdr(Dav1dContext *const c, GetBits *const gb) {
+#define DEBUG_FRAME_HDR 0
+
+#if DEBUG_FRAME_HDR
     const uint8_t *const init_ptr = gb->ptr;
+#endif
     const Av1SequenceHeader *const seqhdr = &c->seq_hdr;
     Av1FrameHeader *const hdr = &c->frame_hdr;
     int res;
 
-#define DEBUG_FRAME_HDR 0
-
     hdr->show_existing_frame =
         !seqhdr->reduced_still_picture_header && dav1d_get_bits(gb, 1);
 #if DEBUG_FRAME_HDR
@@ -335,7 +341,7 @@
             hdr->frame_presentation_delay = dav1d_get_bits(gb, seqhdr->frame_presentation_delay_length);
         if (seqhdr->frame_id_numbers_present)
             hdr->frame_id = dav1d_get_bits(gb, seqhdr->frame_id_n_bits);
-        goto end;
+        return 0; // any trailing bit is read by the caller
     }
 
     hdr->frame_type = seqhdr->reduced_still_picture_header ? DAV1D_FRAME_TYPE_KEY : dav1d_get_bits(gb, 2);
@@ -976,21 +982,14 @@
            (gb->ptr - init_ptr) * 8 - gb->bits_left);
 #endif
 
-end:
+    return 0; // the caller reads any trailing bit and checks for overrun
 
-    if (have_trailing_bit)
-        dav1d_get_bits(gb, 1); // dummy bit
-
-    return dav1d_flush_get_bits(gb) - init_ptr;
-
 error:
     fprintf(stderr, "Error parsing frame header\n");
     return -EINVAL;
 }
 
-static int parse_tile_hdr(Dav1dContext *const c, GetBits *const gb) {
-    const uint8_t *const init_ptr = gb->ptr;
-
+static void parse_tile_hdr(Dav1dContext *const c, GetBits *const gb) {
     int have_tile_pos = 0;
     const int n_tiles = c->frame_hdr.tiling.cols * c->frame_hdr.tiling.rows;
     if (n_tiles > 1)
@@ -1005,8 +1004,31 @@
         c->tile[c->n_tile_data].start = 0;
         c->tile[c->n_tile_data].end = n_tiles - 1;
     }
+}
 
-    return dav1d_flush_get_bits(gb) - init_ptr;
+// Return 1 (after logging) if gb has run off the end of its buffer, or
+// if more than obu_len bytes have been read since it was at init_bit_pos.
+static int
+check_for_overrun(const GetBits *const gb, const unsigned init_bit_pos,
+                  const unsigned obu_len)
+{
+    // Make sure we haven't actually read past the end of the gb buffer
+    if (gb->error) {
+        fprintf(stderr, "Overrun in OBU bit buffer\n");
+        return 1;
+    }
+
+    // init_bit_pos is a position gb had at some point in the past, so
+    // it cannot be larger than the current position.
+    const unsigned pos = dav1d_get_bits_pos(gb);
+    assert(init_bit_pos <= pos);
+
+    // Widen before scaling to bits so that a large obu_len can't wrap.
+    if (pos - init_bit_pos > (uint64_t) obu_len * 8) {
+        fprintf(stderr, "Overrun in OBU bit buffer into next OBU\n");
+        return 1;
+    }
+    return 0;
 }
 
 int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) {
@@ -1041,10 +1063,24 @@
     } while (more);
     if (gb.error) goto error;
 
-    unsigned off = dav1d_flush_get_bits(&gb) - in->data;
-    const unsigned init_off = off;
-    if (len > in->sz - off) goto error;
+    const unsigned init_bit_pos = dav1d_get_bits_pos(&gb);
+    const unsigned init_byte_pos = init_bit_pos >> 3;
+    const unsigned pkt_bytelen = init_byte_pos + len;
 
+    // We must have read a whole number of bytes at this point (1 byte
+    // for the header and whole bytes at a time when reading the
+    // leb128 length field).
+    assert((init_bit_pos & 7) == 0);
+
+    // We also know that we haven't tried to read more than in->sz
+    // bytes yet (otherwise the error flag would have been set by the
+    // code in getbits.c)
+    assert(in->sz >= init_byte_pos);
+
+    // Make sure that there are enough bytes left in the buffer for the
+    // rest of the OBU.
+    if (len > in->sz - init_byte_pos) goto error;
+
     switch (type) {
     case OBU_SEQ_HDR: {
         Av1SequenceHeader hdr, *const hdr_ptr = c->have_seq_hdr ? &hdr : &c->seq_hdr;
@@ -1052,8 +1088,8 @@
         c->have_frame_hdr = 0;
         if ((res = parse_seq_hdr(c, &gb, hdr_ptr)) < 0)
             return res;
-        if ((unsigned)res != len)
-            goto error;
+        if (check_for_overrun(&gb, init_bit_pos, len))
+            return -EINVAL;
         if (!c->have_frame_hdr || memcmp(&hdr, &c->seq_hdr, sizeof(hdr))) {
             for (int i = 0; i < 8; i++) {
                 if (c->refs[i].p.p.data[0])
@@ -1076,7 +1112,7 @@
     case OBU_FRAME_HDR:
         c->have_frame_hdr = 0;
         if (!c->have_seq_hdr) goto error;
-        if ((res = parse_frame_hdr(c, &gb, type != OBU_FRAME)) < 0)
+        if ((res = parse_frame_hdr(c, &gb)) < 0)
             return res;
         c->have_frame_hdr = 1;
         for (int n = 0; n < c->n_tile_data; n++)
@@ -1083,22 +1119,41 @@
             dav1d_data_unref(&c->tile[n].data);
         c->n_tile_data = 0;
         c->n_tiles = 0;
-        if (type != OBU_FRAME) break;
+        if (type != OBU_FRAME) {
+            // This is actually a frame header OBU so read the
+            // trailing bit and check for overrun.
+            dav1d_get_bits(&gb, 1);
+            if (check_for_overrun(&gb, init_bit_pos, len))
+                return -EINVAL;
+
+            break;
+        }
+        // OBU_FRAMEs shouldn't be signalled with show_existing_frame
         if (c->frame_hdr.show_existing_frame) goto error;
-        off += res;
+
+        // This is the frame header at the start of a frame OBU.
+        // There's no trailing bit at the end to skip, but we do need
+        // to align to the next byte.
+        dav1d_bytealign_get_bits(&gb);
         // fall-through
-    case OBU_TILE_GRP:
+    case OBU_TILE_GRP: {
         if (!c->have_frame_hdr) goto error;
         if (c->n_tile_data >= 256) goto error;
-        if ((res = parse_tile_hdr(c, &gb)) < 0)
-            return res;
-        off += res;
-        if (off > len + init_off)
-            goto error;
+        parse_tile_hdr(c, &gb);
+        // Align to the next byte boundary and check for overrun.
+        dav1d_bytealign_get_bits(&gb);
+        if (check_for_overrun(&gb, init_bit_pos, len))
+            return -EINVAL;
+        // The current bit position is a multiple of 8 (because we
+        // just aligned it) and at most 8*pkt_bytelen because
+        // otherwise the overrun check would have fired.
+        const unsigned bit_pos = dav1d_get_bits_pos(&gb);
+        assert((bit_pos & 7) == 0);
+        assert(pkt_bytelen >= (bit_pos >> 3));
         dav1d_ref_inc(in->ref);
         c->tile[c->n_tile_data].data.ref = in->ref;
-        c->tile[c->n_tile_data].data.data = in->data + off;
-        c->tile[c->n_tile_data].data.sz = len + init_off - off;
+        c->tile[c->n_tile_data].data.data = in->data + (bit_pos >> 3);
+        c->tile[c->n_tile_data].data.sz = pkt_bytelen - (bit_pos >> 3);
         // ensure tile groups are in order and sane, see 6.10.1
         if (c->tile[c->n_tile_data].start > c->tile[c->n_tile_data].end ||
             c->tile[c->n_tile_data].start != c->n_tiles)
@@ -1113,6 +1168,7 @@
                          c->tile[c->n_tile_data].start;
         c->n_tile_data++;
         break;
+    }
     case OBU_PADDING:
     case OBU_TD:
     case OBU_METADATA:
@@ -1192,7 +1248,7 @@
         }
     }
 
-    return len + init_off;
+    return len + init_byte_pos;
 
 error:
     fprintf(stderr, "Error parsing OBU data\n");