shithub: dav1d

Download patch

ref: 0276455de73c4a520df12a3d6f80574b988d219a
parent: beda6e0d1c37f06e4e03f7ebe13311bd8b18245e
author: Henrik Gramner <[email protected]>
date: Thu Jun 27 10:32:26 EDT 2019

Consolidate scratch buffers

Also eliminate some pointer chasing by allocating tile context buffers
as part of the struct instead of having the struct contain pointers to
separately allocated buffers.

--- a/src/decode.c
+++ b/src/decode.c
@@ -430,7 +430,7 @@
     // parse new entries
     uint16_t *const pal = f->frame_thread.pass ?
         f->frame_thread.pal[((t->by >> 1) + (t->bx & 1)) * (f->b4_stride >> 1) +
-                            ((t->bx >> 1) + (t->by & 1))][pl] : t->pal[pl];
+                            ((t->bx >> 1) + (t->by & 1))][pl] : t->scratch.pal[pl];
     if (i < pal_sz) {
         int prev = pal[i++] = dav1d_msac_decode_bools(&ts->msac, f->cur.p.bpc);
 
@@ -486,7 +486,7 @@
     const Dav1dFrameContext *const f = t->f;
     uint16_t *const pal = f->frame_thread.pass ?
         f->frame_thread.pal[((t->by >> 1) + (t->bx & 1)) * (f->b4_stride >> 1) +
-                            ((t->bx >> 1) + (t->by & 1))][2] : t->pal[2];
+                            ((t->bx >> 1) + (t->by & 1))][2] : t->scratch.pal[2];
     if (dav1d_msac_decode_bool_equi(&ts->msac)) {
         const int bits = f->cur.p.bpc - 4 +
                          dav1d_msac_decode_bools(&ts->msac, 2);
@@ -588,9 +588,10 @@
     pal_idx[0] = dav1d_msac_decode_uniform(&ts->msac, b->pal_sz[pl]);
     uint16_t (*const color_map_cdf)[8 + 1] =
         ts->cdf.m.color_map[pl][b->pal_sz[pl] - 2];
+    uint8_t (*const order)[8] = t->scratch.pal_order;
+    uint8_t *const ctx = t->scratch.pal_ctx;
     for (int i = 1; i < 4 * (w4 + h4) - 1; i++) {
         // top/left-to-bottom/right diagonals ("wave-front")
-        uint8_t order[64][8], ctx[64];
         const int first = imin(i, w4 * 4 - 1);
         const int last = imax(0, i - h4 * 4 + 1);
         order_palette(pal_idx, stride, i, first, last, order, ctx);
@@ -1211,7 +1212,7 @@
         if (b->pal_sz[0]) {
             uint16_t *const pal = f->frame_thread.pass ?
                 f->frame_thread.pal[((t->by >> 1) + (t->bx & 1)) * (f->b4_stride >> 1) +
-                                    ((t->bx >> 1) + (t->by & 1))][0] : t->pal[0];
+                                    ((t->bx >> 1) + (t->by & 1))][0] : t->scratch.pal[0];
             for (int x = 0; x < bw4; x++)
                 memcpy(t->al_pal[0][bx4 + x][0], pal, 16);
             for (int y = 0; y < bh4; y++)
@@ -1223,15 +1224,18 @@
                 case_set(cbh4, l., 1, cby4);
                 case_set(cbw4, a->, 0, cbx4);
 #undef set_ctx
-            if (b->pal_sz[1]) for (int pl = 1; pl < 3; pl++) {
-                uint16_t *const pal = f->frame_thread.pass ?
-                    f->frame_thread.pal[((t->by >> 1) + (t->bx & 1)) * (f->b4_stride >> 1) +
-                                        ((t->bx >> 1) + (t->by & 1))][pl] : t->pal[pl];
+            if (b->pal_sz[1]) {
+                const uint16_t (*const pal)[8] = f->frame_thread.pass ?
+                    f->frame_thread.pal[((t->by >> 1) + (t->bx & 1)) *
+                    (f->b4_stride >> 1) + ((t->bx >> 1) + (t->by & 1))] :
+                    t->scratch.pal;
                 // see aomedia bug 2183 for why we use luma coordinates here
-                for (int x = 0; x < bw4; x++)
-                    memcpy(t->al_pal[0][bx4 + x][pl], pal, 16);
-                for (int y = 0; y < bh4; y++)
-                    memcpy(t->al_pal[1][by4 + y][pl], pal, 16);
+                for (int pl = 1; pl <= 2; pl++) {
+                    for (int x = 0; x < bw4; x++)
+                        memcpy(t->al_pal[0][bx4 + x][pl], pal[pl], 16);
+                    for (int y = 0; y < bh4; y++)
+                        memcpy(t->al_pal[1][by4 + y][pl], pal[pl], 16);
+                }
             }
         }
         if ((f->frame_hdr->frame_type & 1) || f->frame_hdr->allow_intrabc) {
--- a/src/internal.h
+++ b/src/internal.h
@@ -275,24 +275,54 @@
     Dav1dTileState *ts;
     int bx, by;
     BlockContext l, *a;
-    coef *cf;
-    pixel *emu_edge; // stride=192 for non-SVC, or 320 for SVC
+    ALIGN(union, 32) {
+        int16_t cf_8bpc [32 * 32];
+        int32_t cf_16bpc[32 * 32];
+    };
     // FIXME types can be changed to pixel (and dynamically allocated)
     // which would make copy/assign operations slightly faster?
     uint16_t al_pal[2 /* a/l */][32 /* bx/y4 */][3 /* plane */][8 /* palette_idx */];
-    ALIGN(uint16_t pal[3 /* plane */][8 /* palette_idx */], 16);
     uint8_t pal_sz_uv[2 /* a/l */][32 /* bx4/by4 */];
     uint8_t txtp_map[32 * 32]; // inter-only
-    Dav1dWarpedMotionParams warpmv;
-    union {
-        void *mem;
-        uint8_t *pal_idx;
-        int16_t *ac;
-        pixel *interintra, *lap;
-        int16_t *compinter;
+    ALIGN(union, 32) {
+        struct {
+            union {
+                uint8_t  lap_8bpc [128 * 32];
+                uint16_t lap_16bpc[128 * 32];
+                struct {
+                    int16_t compinter[2][128 * 128];
+                    uint8_t seg_mask[128 * 128];
+                };
+            };
+            union {
+                // stride=192 for non-SVC, or 320 for SVC; warped motion uses stride=32
+                uint8_t  emu_edge_8bpc [320 * (256 + 7)];
+                uint16_t emu_edge_16bpc[320 * (256 + 7)];
+            };
+        };
+        struct {
+            uint8_t interintra_8bpc[64 * 64];
+            uint8_t edge_8bpc[257];
+        };
+        struct {
+            uint16_t interintra_16bpc[64 * 64];
+            uint16_t edge_16bpc[257];
+        };
+        struct {
+            uint8_t pal_idx[2 * 64 * 64];
+            union {
+                struct {
+                    uint8_t pal_order[64][8];
+                    uint8_t pal_ctx[64];
+                };
+                uint8_t levels[36 * 36];
+            };
+            uint16_t pal[3 /* plane */][8 /* palette_idx */];
+        };
+        int16_t ac[32 * 32];
     } scratch;
-    ALIGN(uint8_t scratch_seg_mask[128 * 128], 32);
 
+    Dav1dWarpedMotionParams warpmv;
     Av1Filter *lf_mask;
     int8_t *cur_sb_cdef_idx_ptr;
     // for chroma sub8x8, we need to know the filter for all 4 subblocks in
--- a/src/lib.c
+++ b/src/lib.c
@@ -152,14 +152,7 @@
         for (int m = 0; m < s->n_tile_threads; m++) {
             Dav1dTileContext *const t = &f->tc[m];
             t->f = f;
-            t->cf = dav1d_alloc_aligned(32 * 32 * sizeof(int32_t), 32);
-            if (!t->cf) goto error;
-            t->scratch.mem = dav1d_alloc_aligned(128 * 128 * 4, 32);
-            if (!t->scratch.mem) goto error;
-            memset(t->cf, 0, 32 * 32 * sizeof(int32_t));
-            t->emu_edge =
-                dav1d_alloc_aligned(320 * (256 + 7) * sizeof(uint16_t), 32);
-            if (!t->emu_edge) goto error;
+            memset(t->cf_16bpc, 0, sizeof(t->cf_16bpc));
             if (f->n_tc > 1) {
                 if (pthread_mutex_init(&t->tile_thread.td.lock, NULL)) goto error;
                 if (pthread_cond_init(&t->tile_thread.td.cond, NULL)) {
@@ -500,12 +493,6 @@
             pthread_cond_destroy(&f->tile_thread.cond);
             pthread_cond_destroy(&f->tile_thread.icond);
             freep(&f->tile_thread.task_idx_to_sby_and_tile_idx);
-        }
-        for (int m = 0; f->tc && m < f->n_tc; m++) {
-            Dav1dTileContext *const t = &f->tc[m];
-            dav1d_free_aligned(t->cf);
-            dav1d_free_aligned(t->scratch.mem);
-            dav1d_free_aligned(t->emu_edge);
         }
         for (int m = 0; f->ts && m < f->n_ts; m++) {
             Dav1dTileState *const ts = &f->ts[m];
--- a/src/recon_tmpl.c
+++ b/src/recon_tmpl.c
@@ -165,7 +165,7 @@
     int dc_tok;
 
     if (eob) {
-        ALIGN_STK_16(uint8_t, levels, 36 * 36,);
+        uint8_t *const levels = t->scratch.levels;
         const int sw = imin(t_dim->w, 8), sh = imin(t_dim->h, 8);
         const ptrdiff_t stride = 4 * (sh + 1);
         memset(levels, 0, stride * 4 * (sw + 1));
@@ -394,7 +394,7 @@
             ts->frame_thread.cf += imin(t_dim->w, 8) * imin(t_dim->h, 8) * 16;
             cbi = &f->frame_thread.cbi[t->by * f->b4_stride + t->bx];
         } else {
-            cf = t->cf;
+            cf = bitfn(t->cf);
         }
         if (f->frame_thread.pass != 2) {
             eob = decode_coefs(t, &t->a->lcoef[bx4], &t->l.lcoef[by4],
@@ -613,11 +613,12 @@
             dx + bw4 * h_mul + !!mx * 4 > w ||
             dy + bh4 * v_mul + !!my * 4 > h)
         {
+            pixel *const emu_edge_buf = bitfn(t->scratch.emu_edge);
             f->dsp->mc.emu_edge(bw4 * h_mul + !!mx * 7, bh4 * v_mul + !!my * 7,
                                 w, h, dx - !!mx * 3, dy - !!my * 3,
-                                t->emu_edge, 192 * sizeof(pixel),
+                                emu_edge_buf, 192 * sizeof(pixel),
                                 refp->p.data[pl], ref_stride);
-            ref = &t->emu_edge[192 * !!my * 3 + !!mx * 3];
+            ref = &emu_edge_buf[192 * !!my * 3 + !!mx * 3];
             ref_stride = 192 * sizeof(pixel);
         } else {
             ref = ((pixel *) refp->p.data[pl]) + PXSTRIDE(ref_stride) * dy + dx;
@@ -663,11 +664,12 @@
         const int w = (refp->p.p.w + ss_hor) >> ss_hor;
         const int h = (refp->p.p.h + ss_ver) >> ss_ver;
         if (left < 3 || top < 3 || right + 4 > w || bottom + 4 > h) {
+            pixel *const emu_edge_buf = bitfn(t->scratch.emu_edge);
             f->dsp->mc.emu_edge(right - left + 7, bottom - top + 7,
                                 w, h, left - 3, top - 3,
-                                t->emu_edge, 320 * sizeof(pixel),
+                                emu_edge_buf, 320 * sizeof(pixel),
                                 refp->p.data[pl], ref_stride);
-            ref = &t->emu_edge[320 * 3 + 3];
+            ref = &emu_edge_buf[320 * 3 + 3];
             ref_stride = 320 * sizeof(pixel);
             if (DEBUG_BLOCK_INFO) printf("Emu\n");
         } else {
@@ -702,7 +704,7 @@
     assert(!(t->bx & 1) && !(t->by & 1));
     const Dav1dFrameContext *const f = t->f;
     const refmvs *const r = &f->mvs[t->by * f->b4_stride + t->bx];
-    pixel *const lap = t->scratch.lap;
+    pixel *const lap = bitfn(t->scratch.lap);
     const int ss_ver = !!pl && f->cur.p.layout == DAV1D_PIXEL_LAYOUT_I420;
     const int ss_hor = !!pl && f->cur.p.layout != DAV1D_PIXEL_LAYOUT_I444;
     const int h_mul = 4 >> ss_hor, v_mul = 4 >> ss_ver;
@@ -799,11 +801,12 @@
                 return -1;
             }
             if (dx < 3 || dx + 8 + 4 > width || dy < 3 || dy + 8 + 4 > height) {
+                pixel *const emu_edge_buf = bitfn(t->scratch.emu_edge);
                 f->dsp->mc.emu_edge(15, 15, width, height, dx - 3, dy - 3,
-                                    t->emu_edge, 192 * sizeof(pixel),
+                                    emu_edge_buf, 32 * sizeof(pixel),
                                     refp->p.data[pl], ref_stride);
-                ref_ptr = &t->emu_edge[192 * 3 + 3];
-                ref_stride = 192 * sizeof(pixel);
+                ref_ptr = &emu_edge_buf[32 * 3 + 3];
+                ref_stride = 32 * sizeof(pixel);
             } else {
                 ref_ptr = ((pixel *) refp->p.data[pl]) + PXSTRIDE(ref_stride) * dy + dx;
             }
@@ -842,8 +845,7 @@
     const TxfmInfo *const uv_t_dim = &dav1d_txfm_dimensions[b->uvtx];
 
     // coefficient coding
-    ALIGN_STK_32(pixel, edge_buf, 257,);
-    pixel *const edge = edge_buf + 128;
+    pixel *const edge = bitfn(t->scratch.edge) + 128;
     const int cbw4 = (bw4 + ss_hor) >> ss_hor, cbh4 = (bh4 + ss_ver) >> ss_ver;
 
     const int intra_edge_filter_flag = f->seq_hdr->intra_edge_filter << 10;
@@ -862,7 +864,7 @@
                 }
                 const uint16_t *const pal = f->frame_thread.pass ?
                     f->frame_thread.pal[((t->by >> 1) + (t->bx & 1)) * (f->b4_stride >> 1) +
-                                        ((t->bx >> 1) + (t->by & 1))][0] : t->pal[0];
+                                        ((t->bx >> 1) + (t->by & 1))][0] : t->scratch.pal[0];
                 f->dsp->ipred.pal_pred(dst, f->cur.stride[0], pal,
                                        pal_idx, bw4 * 4, bh4 * 4);
                 if (DEBUG_BLOCK_INFO && DEBUG_B_PIXELS)
@@ -947,7 +949,7 @@
                             txtp = cbi->txtp[0];
                         } else {
                             uint8_t cf_ctx;
-                            cf = t->cf;
+                            cf = bitfn(t->cf);
                             eob = decode_coefs(t, &t->a->lcoef[bx4 + x],
                                                &t->l.lcoef[by4 + y], b->tx, bs,
                                                b, 1, 0, cf, &txtp, &cf_ctx);
@@ -1048,24 +1050,23 @@
             } else if (b->pal_sz[1]) {
                 ptrdiff_t uv_dstoff = 4 * ((t->bx >> ss_hor) +
                                            (t->by >> ss_ver) * PXSTRIDE(f->cur.stride[1]));
+                const uint16_t (*pal)[8];
                 const uint8_t *pal_idx;
                 if (f->frame_thread.pass) {
+                    pal = f->frame_thread.pal[((t->by >> 1) + (t->bx & 1)) * (f->b4_stride >> 1) +
+                                              ((t->bx >> 1) + (t->by & 1))];
                     pal_idx = ts->frame_thread.pal_idx;
                     ts->frame_thread.pal_idx += cbw4 * cbh4 * 16;
                 } else {
+                    pal = t->scratch.pal;
                     pal_idx = &t->scratch.pal_idx[bw4 * bh4 * 16];
                 }
-                const uint16_t *const pal_u = f->frame_thread.pass ?
-                    f->frame_thread.pal[((t->by >> 1) + (t->bx & 1)) * (f->b4_stride >> 1) +
-                                        ((t->bx >> 1) + (t->by & 1))][1] : t->pal[1];
+
                 f->dsp->ipred.pal_pred(((pixel *) f->cur.data[1]) + uv_dstoff,
-                                       f->cur.stride[1], pal_u,
+                                       f->cur.stride[1], pal[1],
                                        pal_idx, cbw4 * 4, cbh4 * 4);
-                const uint16_t *const pal_v = f->frame_thread.pass ?
-                    f->frame_thread.pal[((t->by >> 1) + (t->bx & 1)) * (f->b4_stride >> 1) +
-                                        ((t->bx >> 1) + (t->by & 1))][2] : t->pal[2];
                 f->dsp->ipred.pal_pred(((pixel *) f->cur.data[2]) + uv_dstoff,
-                                       f->cur.stride[1], pal_v,
+                                       f->cur.stride[1], pal[2],
                                        pal_idx, cbw4 * 4, cbh4 * 4);
                 if (DEBUG_BLOCK_INFO && DEBUG_B_PIXELS) {
                     hex_dump(((pixel *) f->cur.data[1]) + uv_dstoff,
@@ -1170,7 +1171,7 @@
                                 txtp = cbi->txtp[pl + 1];
                             } else {
                                 uint8_t cf_ctx;
-                                cf = t->cf;
+                                cf = bitfn(t->cf);
                                 eob = decode_coefs(t, &t->a->ccoef[pl][cbx4 + x],
                                                    &t->l.ccoef[pl][cby4 + y],
                                                    b->uvtx, bs, b, 1, 1 + pl, cf,
@@ -1281,11 +1282,10 @@
             }
         }
         if (b->interintra_type) {
-            ALIGN_STK_32(pixel, tl_edge_buf, 65,);
-            pixel *const tl_edge = tl_edge_buf + 32;
+            pixel *const tl_edge = bitfn(t->scratch.edge) + 32;
             enum IntraPredMode m = b->interintra_mode == II_SMOOTH_PRED ?
                                    SMOOTH_PRED : b->interintra_mode;
-            pixel *const tmp = t->scratch.interintra;
+            pixel *const tmp = bitfn(t->scratch.interintra);
             int angle = 0;
             const pixel *top_sb_edge = NULL;
             if (!(t->by & (f->sb_step - 1))) {
@@ -1415,9 +1415,8 @@
                          dav1d_wedge_masks[bs][chr_layout_idx][0][b->wedge_idx];
 
                 for (int pl = 0; pl < 2; pl++) {
-                    pixel *const tmp = t->scratch.interintra;
-                    ALIGN_STK_32(pixel, tl_edge_px, 65,);
-                    pixel *const tl_edge = &tl_edge_px[32];
+                    pixel *const tmp = bitfn(t->scratch.interintra);
+                    pixel *const tl_edge = bitfn(t->scratch.edge) + 32;
                     enum IntraPredMode m =
                         b->interintra_mode == II_SMOOTH_PRED ?
                         SMOOTH_PRED : b->interintra_mode;
@@ -1455,9 +1454,9 @@
     } else {
         const enum Filter2d filter_2d = b->filter2d;
         // Maximum super block size is 128x128
-        int16_t (*tmp)[128 * 128] = (int16_t (*)[128 * 128]) t->scratch.compinter;
+        int16_t (*tmp)[128 * 128] = t->scratch.compinter;
         int jnt_weight;
-        uint8_t *const seg_mask = t->scratch_seg_mask;
+        uint8_t *const seg_mask = t->scratch.seg_mask;
         const uint8_t *mask;
 
         for (int i = 0; i < 2; i++) {
@@ -1619,7 +1618,7 @@
                             txtp = cbi->txtp[1 + pl];
                         } else {
                             uint8_t cf_ctx;
-                            cf = t->cf;
+                            cf = bitfn(t->cf);
                             txtp = t->txtp_map[(by4 + (y << ss_ver)) * 32 +
                                                 bx4 + (x << ss_hor)];
                             eob = decode_coefs(t, &t->a->ccoef[pl][cbx4 + x],