shithub: dav1d

Download patch

ref: 70b66ff13fc7f082e43777a8c9fa4c0a2ace685e
parent: a62c445d842d7e459062fad8468aceb8f5efaef4
author: Henrik Gramner <[email protected]>
date: Tue Aug 20 14:59:32 EDT 2019

Optimize coef ctx calculations

--- a/include/common/intops.h
+++ b/include/common/intops.h
@@ -40,6 +40,14 @@
     return a < b ? a : b;
 }
 
+static inline unsigned umax(const unsigned a, const unsigned b) { /* larger of two unsigned values */
+    return a > b ? a : b;
+}
+
+static inline unsigned umin(const unsigned a, const unsigned b) { /* smaller of two unsigned values */
+    return a < b ? a : b;
+}
+
 static inline int iclip(const int v, const int min, const int max) {
     return v < min ? min : v > max ? max : v;
 }
--- a/src/env.h
+++ b/src/env.h
@@ -469,180 +469,6 @@ coef ctx helpers removed; see get_skip_ctx(), get_dc_sign_ctx(), get_lo_ctx() in src/recon_tmpl.c
     }
 }
 
-static inline int get_coef_skip_ctx(const TxfmInfo *const t_dim,
-                                    const enum BlockSize bs,
-                                    const uint8_t *const a,
-                                    const uint8_t *const l,
-                                    const int chroma,
-                                    const enum Dav1dPixelLayout layout)
-{
-    const uint8_t *const b_dim = dav1d_block_dimensions[bs];
-
-    if (chroma) {
-        const int ss_ver = layout == DAV1D_PIXEL_LAYOUT_I420;
-        const int ss_hor = layout != DAV1D_PIXEL_LAYOUT_I444;
-        const int not_one_blk = b_dim[2] - (!!b_dim[2] && ss_hor) > t_dim->lw ||
-                                b_dim[3] - (!!b_dim[3] && ss_ver) > t_dim->lh;
-        int ca, cl;
-
-#define MERGE_CTX(dir, type, mask) \
-        c##dir = !!((*(const type *) dir) & mask); \
-        break
-        switch (t_dim->lw) {
-        case TX_4X4:   MERGE_CTX(a, uint8_t,  0x3F);
-        case TX_8X8:   MERGE_CTX(a, uint16_t, 0x3F3F);
-        case TX_16X16: MERGE_CTX(a, uint32_t, 0x3F3F3F3FU);
-        case TX_32X32: MERGE_CTX(a, uint64_t, 0x3F3F3F3F3F3F3F3FULL);
-        default: abort();
-        }
-        switch (t_dim->lh) {
-        case TX_4X4:   MERGE_CTX(l, uint8_t,  0x3F);
-        case TX_8X8:   MERGE_CTX(l, uint16_t, 0x3F3F);
-        case TX_16X16: MERGE_CTX(l, uint32_t, 0x3F3F3F3FU);
-        case TX_32X32: MERGE_CTX(l, uint64_t, 0x3F3F3F3F3F3F3F3FULL);
-        default: abort();
-        }
-#undef MERGE_CTX
-
-        return 7 + not_one_blk * 3 + ca + cl;
-    } else if (b_dim[2] == t_dim->lw && b_dim[3] == t_dim->lh) {
-        return 0;
-    } else {
-        static const uint8_t skip_contexts[5][5] = {
-            { 1, 2, 2, 2, 3 },
-            { 1, 4, 4, 4, 5 },
-            { 1, 4, 4, 4, 5 },
-            { 1, 4, 4, 4, 5 },
-            { 1, 4, 4, 4, 6 }
-        };
-        uint64_t la, ll;
-
-#define MERGE_CTX(dir, type, tx) do { \
-            l##dir = *(const type *) dir; \
-            if (tx == TX_64X64) \
-                l##dir |= *(const type *) &dir[sizeof(type)]; \
-            if (tx >= TX_32X32) l##dir |= l##dir >> 32; \
-            if (tx >= TX_16X16) l##dir |= l##dir >> 16; \
-            if (tx >= TX_8X8)   l##dir |= l##dir >> 8; \
-            l##dir &= 0x3F; \
-        } while (0); \
-        break
-        switch (t_dim->lw) {
-        case TX_4X4:   MERGE_CTX(a, uint8_t,  TX_4X4);
-        case TX_8X8:   MERGE_CTX(a, uint16_t, TX_8X8);
-        case TX_16X16: MERGE_CTX(a, uint32_t, TX_16X16);
-        case TX_32X32: MERGE_CTX(a, uint64_t, TX_32X32);
-        case TX_64X64: MERGE_CTX(a, uint64_t, TX_64X64);
-        }
-        switch (t_dim->lh) {
-        case TX_4X4:   MERGE_CTX(l, uint8_t,  TX_4X4);
-        case TX_8X8:   MERGE_CTX(l, uint16_t, TX_8X8);
-        case TX_16X16: MERGE_CTX(l, uint32_t, TX_16X16);
-        case TX_32X32: MERGE_CTX(l, uint64_t, TX_32X32);
-        case TX_64X64: MERGE_CTX(l, uint64_t, TX_64X64);
-        }
-#undef MERGE_CTX
-
-        const int max = imin((int) (la | ll), 4);
-        const int min = imin(imin((int) la, (int) ll), 4);
-
-        return skip_contexts[min][max];
-    }
-}
-
-static inline int get_coef_nz_ctx(uint8_t *const levels,
-                                  const enum RectTxfmSize tx,
-                                  const enum TxClass tx_class,
-                                  const int x, const int y,
-                                  const ptrdiff_t stride)
-{
-    static const uint8_t offsets[3][5][2 /* x, y */] = {
-        [TX_CLASS_2D] = {
-            { 0, 1 }, { 1, 0 }, { 2, 0 }, { 0, 2 }, { 1, 1 }
-        }, [TX_CLASS_V] = {
-            { 0, 1 }, { 1, 0 }, { 0, 2 }, { 0, 3 }, { 0, 4 }
-        }, [TX_CLASS_H] = {
-            { 0, 1 }, { 1, 0 }, { 2, 0 }, { 3, 0 }, { 4, 0 }
-        }
-    };
-    const uint8_t (*const off)[2] = offsets[tx_class];
-    int mag = 0;
-    for (int i = 0; i < 5; i++)
-        mag += imin(levels[(x + off[i][0]) * stride + (y + off[i][1])], 3);
-    const int ctx = imin((mag + 1) >> 1, 4);
-    if (tx_class == TX_CLASS_2D) {
-        return dav1d_nz_map_ctx_offset[tx][imin(y, 4)][imin(x, 4)] + ctx;
-    } else {
-        return 26 + imin((tx_class == TX_CLASS_V) ? y : x, 2) * 5 + ctx;
-    }
-}
-
-static inline int get_dc_sign_ctx(const TxfmInfo *const t_dim,
-                                  const uint8_t *const a,
-                                  const uint8_t *const l)
-{
-    uint64_t sa, sl;
-
-#define MERGE_CTX(dir, type, tx, mask) do { \
-        s##dir = ((*(const type *) dir) >> 6) & mask; \
-        if (tx == TX_64X64) \
-            s##dir += ((*(const type *) &dir[sizeof(type)]) >> 6) & mask; \
-        if (tx >= TX_32X32) s##dir += s##dir >> 32; \
-        if (tx >= TX_16X16) s##dir += s##dir >> 16; \
-        if (tx >= TX_8X8)   s##dir += s##dir >> 8; \
-    } while (0); \
-    break
-    switch (t_dim->lw) {
-    case TX_4X4:   MERGE_CTX(a, uint8_t,  TX_4X4,   0x03);
-    case TX_8X8:   MERGE_CTX(a, uint16_t, TX_8X8,   0x0303);
-    case TX_16X16: MERGE_CTX(a, uint32_t, TX_16X16, 0x03030303U);
-    case TX_32X32: MERGE_CTX(a, uint64_t, TX_32X32, 0x0303030303030303ULL);
-    case TX_64X64: MERGE_CTX(a, uint64_t, TX_64X64, 0x0303030303030303ULL);
-    }
-    switch (t_dim->lh) {
-    case TX_4X4:   MERGE_CTX(l, uint8_t,  TX_4X4,   0x03);
-    case TX_8X8:   MERGE_CTX(l, uint16_t, TX_8X8,   0x0303);
-    case TX_16X16: MERGE_CTX(l, uint32_t, TX_16X16, 0x03030303U);
-    case TX_32X32: MERGE_CTX(l, uint64_t, TX_32X32, 0x0303030303030303ULL);
-    case TX_64X64: MERGE_CTX(l, uint64_t, TX_64X64, 0x0303030303030303ULL);
-    }
-#undef MERGE_CTX
-    const int s = ((int) ((sa + sl) & 0xFF)) - (t_dim->w + t_dim->h);
-
-    return s < 0 ? 1 : s > 0 ? 2 : 0;
-}
-
-static inline int get_br_ctx(const uint8_t *const levels,
-                             const int ac, const enum TxClass tx_class,
-                             const int x, const int y,
-                             const ptrdiff_t stride)
-{
-    int mag = 0;
-    static const uint8_t offsets_from_txclass[3][3][2] = {
-        [TX_CLASS_2D] = { { 0, 1 }, { 1, 0 }, { 1, 1 } },
-        [TX_CLASS_H]  = { { 0, 1 }, { 1, 0 }, { 0, 2 } },
-        [TX_CLASS_V]  = { { 0, 1 }, { 1, 0 }, { 2, 0 } }
-    };
-    const uint8_t (*const offsets)[2] = offsets_from_txclass[tx_class];
-    for (int i = 0; i < 3; i++)
-        mag += levels[(x + offsets[i][1]) * stride + y + offsets[i][0]];
-
-    mag = imin((mag + 1) >> 1, 6);
-    if (!ac) return mag;
-    switch (tx_class) {
-    case TX_CLASS_2D:
-        if (y < 2 && x < 2) return mag + 7;
-        break;
-    case TX_CLASS_H:
-        if (x == 0) return mag + 7;
-        break;
-    case TX_CLASS_V:
-        if (y == 0) return mag + 7;
-        break;
-    }
-    return mag + 14;
-}
-
 static inline mv get_gmv_2d(const Dav1dWarpedMotionParams *const gmv,
                             const int bx4, const int by4,
                             const int bw4, const int bh4,
--- a/src/internal.h
+++ b/src/internal.h
@@ -309,14 +309,14 @@
             uint16_t edge_16bpc[257];
         };
         struct {
-            uint8_t pal_idx[2 * 64 * 64];
             union {
+                uint8_t levels[32 * 34]; /* 2D class max: stride 4 * sh <= 32, 4 * sw + 2 <= 34 rows */
                 struct {
                     uint8_t pal_order[64][8];
                     uint8_t pal_ctx[64];
                 };
-                uint8_t levels[36 * 36];
             };
+            uint8_t pal_idx[2 * 64 * 64];
             uint16_t pal[3 /* plane */][8 /* palette_idx */];
         };
         int16_t ac[32 * 32];
--- a/src/recon_tmpl.c
+++ b/src/recon_tmpl.c
@@ -46,16 +46,273 @@
 #include "src/tables.h"
 #include "src/wedge.h"
 
-static unsigned read_golomb(MsacContext *const msac) {
+static inline unsigned read_golomb(MsacContext *const msac) { /* Exp-Golomb: zero-run prefix (max 32), then len suffix bits */
     int len = 0;
     unsigned val = 1;
 
     while (!dav1d_msac_decode_bool_equi(msac) && len < 32) len++;
-    while (len--) val = (val << 1) | dav1d_msac_decode_bool_equi(msac);
+    while (len--) val = (val << 1) + dav1d_msac_decode_bool_equi(msac); /* low bit is 0 after the shift, so + acts as | */
 
     return val - 1;
 }
 
+static inline unsigned get_skip_ctx(const TxfmInfo *const t_dim,
+                                    const enum BlockSize bs,
+                                    const uint8_t *const a,
+                                    const uint8_t *const l,
+                                    const int chroma,
+                                    const enum Dav1dPixelLayout layout)
+{
+    const uint8_t *const b_dim = dav1d_block_dimensions[bs];
+    /* Result: 7..12 for chroma; for luma, 0 when b_dim[2]/b_dim[3] equal t_dim->lw/lh, else 1..6. */
+    if (chroma) {
+        const int ss_ver = layout == DAV1D_PIXEL_LAYOUT_I420;
+        const int ss_hor = layout != DAV1D_PIXEL_LAYOUT_I444;
+        const int not_one_blk = b_dim[2] - (!!b_dim[2] && ss_hor) > t_dim->lw ||
+                                b_dim[3] - (!!b_dim[3] && ss_ver) > t_dim->lh;
+        int ca, cl;
+
+#define MERGE_CTX(dir, type, mask) \
+        c##dir = !!((*(const type *) dir) & mask); \
+        break
+        /* ca/cl = 1 if any context byte covered by the txfm has nonzero low 6 bits */
+        switch (t_dim->lw) {
+        case TX_4X4:   MERGE_CTX(a, uint8_t,  0x3F);
+        case TX_8X8:   MERGE_CTX(a, uint16_t, 0x3F3F);
+        case TX_16X16: MERGE_CTX(a, uint32_t, 0x3F3F3F3FU);
+        case TX_32X32: MERGE_CTX(a, uint64_t, 0x3F3F3F3F3F3F3F3FULL);
+        default: assert(0); /* NOTE(review): ca is left unset if reached with NDEBUG */
+        }
+        switch (t_dim->lh) {
+        case TX_4X4:   MERGE_CTX(l, uint8_t,  0x3F);
+        case TX_8X8:   MERGE_CTX(l, uint16_t, 0x3F3F);
+        case TX_16X16: MERGE_CTX(l, uint32_t, 0x3F3F3F3FU);
+        case TX_32X32: MERGE_CTX(l, uint64_t, 0x3F3F3F3F3F3F3F3FULL);
+        default: assert(0); /* NOTE(review): cl is left unset if reached with NDEBUG */
+        }
+#undef MERGE_CTX
+
+        return 7 + not_one_blk * 3 + ca + cl;
+    } else if (b_dim[2] == t_dim->lw && b_dim[3] == t_dim->lh) {
+        return 0;
+    } else {
+        unsigned la, ll;
+
+#define MERGE_CTX(dir, type, tx) \
+        if (tx == TX_64X64) { \
+            uint64_t tmp = *(const uint64_t *) dir; \
+            tmp |= *(const uint64_t *) &dir[8]; \
+            l##dir = (unsigned) (tmp >> 32) | (unsigned) tmp; /* 16 bytes -> 32 bits */ \
+        } else \
+            l##dir = *(const type *) dir; \
+        if (tx == TX_32X32) l##dir |= *(const type *) &dir[sizeof(type)]; /* 8 bytes -> 32 bits */ \
+        if (tx >= TX_16X16) l##dir |= l##dir >> 16; \
+        if (tx >= TX_8X8)   l##dir |= l##dir >> 8; \
+        break
+        /* After folding, the low byte of la/ll is the OR of every covered context byte. */
+        switch (t_dim->lw) {
+        case TX_4X4:   MERGE_CTX(a, uint8_t,  TX_4X4);
+        case TX_8X8:   MERGE_CTX(a, uint16_t, TX_8X8);
+        case TX_16X16: MERGE_CTX(a, uint32_t, TX_16X16);
+        case TX_32X32: MERGE_CTX(a, uint32_t, TX_32X32);
+        case TX_64X64: MERGE_CTX(a, uint32_t, TX_64X64);
+        default: assert(0); /* NOTE(review): la is left unset if reached with NDEBUG */
+        }
+        switch (t_dim->lh) {
+        case TX_4X4:   MERGE_CTX(l, uint8_t,  TX_4X4);
+        case TX_8X8:   MERGE_CTX(l, uint16_t, TX_8X8);
+        case TX_16X16: MERGE_CTX(l, uint32_t, TX_16X16);
+        case TX_32X32: MERGE_CTX(l, uint32_t, TX_32X32);
+        case TX_64X64: MERGE_CTX(l, uint32_t, TX_64X64);
+        default: assert(0); /* NOTE(review): ll is left unset if reached with NDEBUG */
+        }
+#undef MERGE_CTX
+        /* Only the low 6 bits count; dav1d_skip_ctx is symmetric in its two indices. */
+        return dav1d_skip_ctx[umin(la & 0x3F, 4)][umin(ll & 0x3F, 4)];
+    }
+}
+/* DC sign ctx, derived from bits 6-7 of the a and l context bytes covered by tx. */
+static inline unsigned get_dc_sign_ctx(const int /*enum RectTxfmSize*/ tx,
+                                       const uint8_t *const a,
+                                       const uint8_t *const l)
+{
+    uint64_t mask = 0xC0C0C0C0C0C0C0C0ULL, mul = 0x0101010101010101ULL;
+    int s;
+    /* mask keeps bits 6-7 of every byte; multiplying by mul sums all bytes into the top byte. */
+#if ARCH_X86_64 && defined(__GNUC__)
+    /* Coerce compilers into producing better code. For some reason
+     * every x86-64 compiler is awful at handling 64-bit constants. */
+    __asm__("" : "+r"(mask), "+r"(mul));
+#endif
+    /* Each case sets s = (sum of the 2-bit fields) - (number of bytes read from a and l). */
+    switch(tx) {
+    case TX_4X4: {
+        int t = *(const uint8_t *) a >> 6;
+        t    += *(const uint8_t *) l >> 6;
+        s = t - 1 - 1;
+        break;
+    }
+    case TX_8X8: {
+        uint32_t t = *(const uint16_t *) a & (uint32_t) mask;
+        t         += *(const uint16_t *) l & (uint32_t) mask;
+        t *= 0x04040404U; /* 4 * 0x01010101: the >> 6 is folded into the multiply */
+        s = (int) (t >> 24) - 2 - 2;
+        break;
+    }
+    case TX_16X16: {
+        uint32_t t = (*(const uint32_t *) a & (uint32_t) mask) >> 6; /* shift before adding: top bytes could overflow */
+        t         += (*(const uint32_t *) l & (uint32_t) mask) >> 6;
+        t *= (uint32_t) mul;
+        s = (int) (t >> 24) - 4 - 4;
+        break;
+    }
+    case TX_32X32: {
+        uint64_t t = (*(const uint64_t *) a & mask) >> 6;
+        t         += (*(const uint64_t *) l & mask) >> 6;
+        t *= mul;
+        s = (int) (t >> 56) - 8 - 8;
+        break;
+    }
+    case TX_64X64: {
+        uint64_t t = (*(const uint64_t *) &a[0] & mask) >> 6;
+        t         += (*(const uint64_t *) &a[8] & mask) >> 6;
+        t         += (*(const uint64_t *) &l[0] & mask) >> 6;
+        t         += (*(const uint64_t *) &l[8] & mask) >> 6;
+        t *= mul;
+        s = (int) (t >> 56) - 16 - 16;
+        break;
+    }
+    case RTX_4X8: {
+        uint32_t t = *(const uint8_t  *) a & (uint32_t) mask;
+        t         += *(const uint16_t *) l & (uint32_t) mask;
+        t *= 0x04040404U;
+        s = (int) (t >> 24) - 1 - 2;
+        break;
+    }
+    case RTX_8X4: {
+        uint32_t t = *(const uint16_t *) a & (uint32_t) mask;
+        t         += *(const uint8_t  *) l & (uint32_t) mask;
+        t *= 0x04040404U;
+        s = (int) (t >> 24) - 2 - 1;
+        break;
+    }
+    case RTX_8X16: {
+        uint32_t t = *(const uint16_t *) a & (uint32_t) mask;
+        t         += *(const uint32_t *) l & (uint32_t) mask;
+        t = (t >> 6) * (uint32_t) mul;
+        s = (int) (t >> 24) - 2 - 4;
+        break;
+    }
+    case RTX_16X8: {
+        uint32_t t = *(const uint32_t *) a & (uint32_t) mask;
+        t         += *(const uint16_t *) l & (uint32_t) mask;
+        t = (t >> 6) * (uint32_t) mul;
+        s = (int) (t >> 24) - 4 - 2;
+        break;
+    }
+    case RTX_16X32: {
+        uint64_t t = *(const uint32_t *) a & (uint32_t) mask;
+        t         += *(const uint64_t *) l & mask;
+        t = (t >> 6) * mul;
+        s = (int) (t >> 56) - 4 - 8;
+        break;
+    }
+    case RTX_32X16: {
+        uint64_t t = *(const uint64_t *) a & mask;
+        t         += *(const uint32_t *) l & (uint32_t) mask;
+        t = (t >> 6) * mul;
+        s = (int) (t >> 56) - 8 - 4;
+        break;
+    }
+    case RTX_32X64: {
+        uint64_t t = (*(const uint64_t *) &a[0] & mask) >> 6;
+        t         += (*(const uint64_t *) &l[0] & mask) >> 6;
+        t         += (*(const uint64_t *) &l[8] & mask) >> 6;
+        t *= mul;
+        s = (int) (t >> 56) - 8 - 16;
+        break;
+    }
+    case RTX_64X32: {
+        uint64_t t = (*(const uint64_t *) &a[0] & mask) >> 6;
+        t         += (*(const uint64_t *) &a[8] & mask) >> 6;
+        t         += (*(const uint64_t *) &l[0] & mask) >> 6;
+        t *= mul;
+        s = (int) (t >> 56) - 16 - 8;
+        break;
+    }
+    case RTX_4X16: {
+        uint32_t t = *(const uint8_t  *) a & (uint32_t) mask;
+        t         += *(const uint32_t *) l & (uint32_t) mask;
+        t = (t >> 6) * (uint32_t) mul;
+        s = (int) (t >> 24) - 1 - 4;
+        break;
+    }
+    case RTX_16X4: {
+        uint32_t t = *(const uint32_t *) a & (uint32_t) mask;
+        t         += *(const uint8_t  *) l & (uint32_t) mask;
+        t = (t >> 6) * (uint32_t) mul;
+        s = (int) (t >> 24) - 4 - 1;
+        break;
+    }
+    case RTX_8X32: {
+        uint64_t t = *(const uint16_t *) a & (uint32_t) mask;
+        t         += *(const uint64_t *) l & mask;
+        t = (t >> 6) * mul;
+        s = (int) (t >> 56) - 2 - 8;
+        break;
+    }
+    case RTX_32X8: {
+        uint64_t t = *(const uint64_t *) a & mask;
+        t         += *(const uint16_t *) l & (uint32_t) mask;
+        t = (t >> 6) * mul;
+        s = (int) (t >> 56) - 8 - 2;
+        break;
+    }
+    case RTX_16X64: {
+        uint64_t t = *(const uint32_t *) a & (uint32_t) mask;
+        t         += *(const uint64_t *) &l[0] & mask;
+        t = (t >> 6) + ((*(const uint64_t *) &l[8] & mask) >> 6);
+        t *= mul;
+        s = (int) (t >> 56) - 4 - 16;
+        break;
+    }
+    case RTX_64X16: {
+        uint64_t t = *(const uint64_t *) &a[0] & mask;
+        t         += *(const uint32_t *) l & (uint32_t) mask;
+        t = (t >> 6) + ((*(const uint64_t *) &a[8] & mask) >> 6);
+        t *= mul;
+        s = (int) (t >> 56) - 16 - 4;
+        break;
+    }
+    default: assert(0); /* NOTE(review): s is left unset if reached with NDEBUG */
+    }
+    /* s < 0 -> 1, s == 0 -> 0, s > 0 -> 2 */
+    return (s != 0) + (s > 0);
+}
+/* Lo-token ctx for the coefficient at levels[0]; *hi_mag receives the 3-neighbour sum used for the hi-token ctx. */
+static inline unsigned get_lo_ctx(const uint8_t *const levels,
+                                  const enum TxClass tx_class,
+                                  unsigned *const hi_mag,
+                                  const uint8_t (*const ctx_offsets)[5],
+                                  const unsigned x, const unsigned y,
+                                  const ptrdiff_t stride)
+{
+    unsigned mag = levels[0 * stride + 1] + levels[1 * stride + 0];
+    unsigned offset;
+    if (tx_class == TX_CLASS_2D) {
+        mag += levels[1 * stride + 1];
+        *hi_mag = mag;
+        mag += levels[0 * stride + 2] + levels[2 * stride + 0];
+        offset = ctx_offsets[umin(y, 4)][umin(x, 4)];
+    } else {
+        mag += levels[0 * stride + 2];
+        *hi_mag = mag;
+        mag += levels[0 * stride + 3] + levels[0 * stride + 4];
+        offset = 26 + (y > 1 ? 10 : y * 5); /* 26 + min(y, 2) * 5 */
+    }
+    return offset + (mag > 512 ? 4 : (mag + 64) >> 7); /* == min((sum of bits 6-7 + 1) >> 1, 4), assuming toks <= 15 */
+}
+
 static int decode_coefs(Dav1dTileContext *const t,
                         uint8_t *const a, uint8_t *const l,
                         const enum RectTxfmSize tx, const enum BlockSize bs,
@@ -74,7 +331,7 @@ static int decode_coefs(Dav1dTileContext *const t,
         printf("Start: r=%d\n", ts->msac.rng);
 
     // does this block have any non-zero coefficients
-    const int sctx = get_coef_skip_ctx(t_dim, bs, a, l, chroma, f->cur.p.layout);
+    const int sctx = get_skip_ctx(t_dim, bs, a, l, chroma, f->cur.p.layout);
     const int all_skip = dav1d_msac_decode_bool_adapt(&ts->msac,
                              ts->cdf.coef.skip[t_dim->ctx][sctx]);
     if (dbg)
@@ -175,91 +432,126 @@ static int decode_coefs(Dav1dTileContext *const t,
     }
 
     // base tokens
-    uint16_t (*const br_cdf)[4] =
-        ts->cdf.coef.br_tok[imin(t_dim->ctx, 3)][chroma];
+    uint16_t (*const eob_cdf)[4] = ts->cdf.coef.eob_base_tok[t_dim->ctx][chroma];
+    uint16_t (*const hi_cdf)[4] = ts->cdf.coef.br_tok[imin(t_dim->ctx, 3)][chroma];
     const uint16_t *const scan = dav1d_scans[tx][tx_class];
     int dc_tok;
 
     if (eob) {
-        uint8_t *const levels = t->scratch.levels;
+        uint16_t (*const lo_cdf)[4] = ts->cdf.coef.base_tok[t_dim->ctx][chroma];
+        uint8_t *const levels = t->scratch.levels; // bits 0-5: tok, 6-7: lo_tok (min(tok, 3))
         const int sw = imin(t_dim->w, 8), sh = imin(t_dim->h, 8);
-        const ptrdiff_t stride = 4 * (sh + 1);
-        memset(levels, 0, stride * 4 * (sw + 1));
-        const int shift = 2 + imin(t_dim->lh, 3), mask = 4 * sh - 1;
+        const unsigned shift = 2 + imin(t_dim->lh, 3), mask = 4 * sh - 1;
 
-        { // eob
-            const int rc = scan[eob], x = rc >> shift, y = rc & mask;
+        /* eob: the coefficient at scan position eob; its tok is always >= 1 */
+        unsigned rc = scan[eob], x = rc >> shift, y = rc & mask;
+        unsigned ctx = 1 + (eob > sw * sh * 2) + (eob > sw * sh * 4);
+        int eob_tok = dav1d_msac_decode_symbol_adapt4(&ts->msac, eob_cdf[ctx], 2);
+        int tok = eob_tok + 1;
+        int level_tok = tok * 0x41; // tok in bits 0-5 and again in bits 6-7
+        unsigned mag;
+        if (dbg)
+            printf("Post-lo_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
+                   t_dim->ctx, chroma, ctx, eob, rc, tok, ts->msac.rng);
 
-            const int ctx = 1 + (eob > sw * sh * 2) + (eob > sw * sh * 4);
-            uint16_t *const lo_cdf = ts->cdf.coef.eob_base_tok[t_dim->ctx][chroma][ctx];
+#define DECODE_COEFS_CLASS(tx_class) \
+        if (eob_tok == 2) { \
+            ctx = (tx_class == TX_CLASS_2D ? (x | y) > 1 : \
+                   tx_class == TX_CLASS_H ? x != 0 : y != 0) ? 14 : 7; \
+            tok = dav1d_msac_decode_hi_tok(&ts->msac, hi_cdf[ctx]); \
+            level_tok = tok + (3 << 6); /* lo_tok saturates at 3 */ \
+            if (dbg) \
+                printf("Post-hi_tok[%d][%d][%d][%d=%d=%d]: r=%d\n", \
+                       imin(t_dim->ctx, 3), chroma, ctx, eob, rc, tok, \
+                       ts->msac.rng); \
+        } \
+        cf[rc] = tok; \
+        if (tx_class == TX_CLASS_H) \
+            /* Transposing reduces the stride and padding requirements */ \
+            levels[y * stride + x] = (uint8_t) level_tok; \
+        else \
+            levels[x * stride + y] = (uint8_t) level_tok; \
+        for (int i = eob - 1; i > 0; i--) { /* ac */ \
+            if (tx_class == TX_CLASS_H) \
+                rc = i, x = rc & mask, y = rc >> shift; /* NOTE: assumes the TX_CLASS_H scan is the identity (scan[i] == i) */ \
+            else \
+                rc = scan[i], x = rc >> shift, y = rc & mask; \
+            assert(x < 32 && y < 32); \
+            uint8_t *const level = levels + x * stride + y; \
+            ctx = get_lo_ctx(level, tx_class, &mag, lo_ctx_offsets, x, y, stride); \
+            if (tx_class == TX_CLASS_2D) \
+                y |= x; /* y > 1 now iff x > 1 || y > 1 */ \
+            tok = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf[ctx], 3); \
+            level_tok = tok * 0x41; \
+            if (dbg) \
+                printf("Post-lo_tok[%d][%d][%d][%d=%d=%d]: r=%d\n", \
+                       t_dim->ctx, chroma, ctx, i, rc, tok, ts->msac.rng); \
+            if (tok == 3) { \
+                mag &= 63; /* tok sum (bits 0-5) of the 3 nearest neighbours */ \
+                ctx = (y > (tx_class == TX_CLASS_2D) ? 14 : 7) + \
+                      (mag > 12 ? 6 : (mag + 1) >> 1); /* min((mag + 1) >> 1, 6) */ \
+                tok = dav1d_msac_decode_hi_tok(&ts->msac, hi_cdf[ctx]); \
+                level_tok = tok + (3 << 6); \
+                if (dbg) \
+                    printf("Post-hi_tok[%d][%d][%d][%d=%d=%d]: r=%d\n", \
+                           imin(t_dim->ctx, 3), chroma, ctx, i, rc, tok, \
+                           ts->msac.rng); \
+            } \
+            cf[rc] = tok; \
+            *level = (uint8_t) level_tok; \
+        } \
+        /* dc */ \
+        ctx = (tx_class == TX_CLASS_2D) ? 0 : \
+            get_lo_ctx(levels, tx_class, &mag, lo_ctx_offsets, 0, 0, stride); \
+        dc_tok = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf[ctx], 3); \
+        if (dbg) \
+            printf("Post-dc_lo_tok[%d][%d][%d][%d]: r=%d\n", \
+                   t_dim->ctx, chroma, ctx, dc_tok, ts->msac.rng); \
+        if (dc_tok == 3) { \
+            if (tx_class == TX_CLASS_2D) /* get_lo_ctx() was not called for 2D dc */ \
+                mag = levels[0 * stride + 1] + levels[1 * stride + 0] + \
+                      levels[1 * stride + 1]; \
+            mag &= 63; \
+            ctx = mag > 12 ? 6 : (mag + 1) >> 1; \
+            dc_tok = dav1d_msac_decode_hi_tok(&ts->msac, hi_cdf[ctx]); \
+            if (dbg) \
+                printf("Post-dc_hi_tok[%d][%d][0][%d]: r=%d\n", \
+                       imin(t_dim->ctx, 3), chroma, dc_tok, ts->msac.rng); \
+        } \
+        break
+        /* Expanded once per class below, so tx_class is a constant in each copy. */
-            int tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf, 2);
-            int tok = 1 + tok_br;
-            if (dbg)
-                printf("Post-lo_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
-                       t_dim->ctx, chroma, ctx, eob, rc, tok, ts->msac.rng);
-
-            if (tok_br == 2) {
-                const int br_ctx = get_br_ctx(levels, 1, tx_class, x, y, stride);
-                tok = dav1d_msac_decode_hi_tok(&ts->msac, br_cdf[br_ctx]);
-                if (dbg)
-                    printf("Post-hi_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
-                           imin(t_dim->ctx, 3), chroma, br_ctx, eob, rc, tok,
-                           ts->msac.rng);
-            }
-
-            cf[rc] = tok;
-            levels[x * stride + y] = (uint8_t) tok;
+        switch (tx_class) {
+        case TX_CLASS_2D: {
+            const unsigned nonsquare_tx = tx >= RTX_4X8;
+            const uint8_t (*const lo_ctx_offsets)[5] =
+                dav1d_lo_ctx_offsets[nonsquare_tx + (tx & nonsquare_tx)]; /* NOTE: assumes odd RTX_* values are the w < h sizes */
+            const ptrdiff_t stride = 4 * sh;
+            memset(levels, 0, stride * (4 * sw + 2)); /* 2 zero rows of padding for the x + 1 / x + 2 reads */
+            DECODE_COEFS_CLASS(TX_CLASS_2D);
         }
-        for (int i = eob - 1; i > 0; i--) { // ac
-            const int rc = scan[i], x = rc >> shift, y = rc & mask;
-
-            // lo tok
-            const int ctx = get_coef_nz_ctx(levels, tx, tx_class, x, y, stride);
-            uint16_t *const lo_cdf = ts->cdf.coef.base_tok[t_dim->ctx][chroma][ctx];
-            int tok = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf, 3);
-            if (dbg)
-                printf("Post-lo_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
-                       t_dim->ctx, chroma, ctx, i, rc, tok, ts->msac.rng);
-
-            if (tok == 3) {
-                const int br_ctx = get_br_ctx(levels, 1, tx_class, x, y, stride);
-                tok = dav1d_msac_decode_hi_tok(&ts->msac, br_cdf[br_ctx]);
-                if (dbg)
-                    printf("Post-hi_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
-                           imin(t_dim->ctx, 3), chroma, br_ctx, i, rc, tok,
-                           ts->msac.rng);
-            }
-            cf[rc] = tok;
-            levels[x * stride + y] = (uint8_t) tok;
+        case TX_CLASS_H: {
+#define lo_ctx_offsets NULL /* only read for TX_CLASS_2D; stays defined for TX_CLASS_V too */
+            const ptrdiff_t stride = 16; /* NOTE: assumes 1D-class rows hold at most 16 entries */
+            memset(levels, 0, stride * (4 * sh + 2));
+            DECODE_COEFS_CLASS(TX_CLASS_H);
         }
-        { // dc
-            const int ctx = (tx_class != TX_CLASS_2D) ?
-                get_coef_nz_ctx(levels, tx, tx_class, 0, 0, stride) : 0;
-            uint16_t *const lo_cdf = ts->cdf.coef.base_tok[t_dim->ctx][chroma][ctx];
-            dc_tok = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf, 3);
-            if (dbg)
-                printf("Post-dc_lo_tok[%d][%d][%d][%d]: r=%d\n",
-                       t_dim->ctx, chroma, ctx, dc_tok, ts->msac.rng);
-
-            if (dc_tok == 3) {
-                const int br_ctx = get_br_ctx(levels, 0, tx_class, 0, 0, stride);
-                dc_tok = dav1d_msac_decode_hi_tok(&ts->msac, br_cdf[br_ctx]);
-                if (dbg)
-                    printf("Post-dc_hi_tok[%d][%d][0][%d]: r=%d\n",
-                           imin(t_dim->ctx, 3), chroma, dc_tok, ts->msac.rng);
-            }
+        case TX_CLASS_V: {
+            const ptrdiff_t stride = 16; /* NOTE: assumes 1D-class rows hold at most 16 entries */
+            memset(levels, 0, stride * (4 * sw + 2));
+            DECODE_COEFS_CLASS(TX_CLASS_V);
         }
+#undef lo_ctx_offsets
+#undef DECODE_COEFS_CLASS
+        default: assert(0);
+        }
     } else { // dc-only
-        uint16_t *const lo_cdf = ts->cdf.coef.eob_base_tok[t_dim->ctx][chroma][0];
-        int tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf, 2);
+        int tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, eob_cdf[0], 2);
         dc_tok = 1 + tok_br;
         if (dbg)
             printf("Post-dc_lo_tok[%d][%d][%d][%d]: r=%d\n",
                    t_dim->ctx, chroma, 0, dc_tok, ts->msac.rng);
-
         if (tok_br == 2) {
-            dc_tok = dav1d_msac_decode_hi_tok(&ts->msac, br_cdf[0]);
+            dc_tok = dav1d_msac_decode_hi_tok(&ts->msac, hi_cdf[0]);
             if (dbg)
                 printf("Post-dc_hi_tok[%d][%d][0][%d]: r=%d\n",
                        imin(t_dim->ctx, 3), chroma, dc_tok, ts->msac.rng);
@@ -276,7 +568,7 @@ static int decode_coefs(Dav1dTileContext *const t,
     unsigned cul_level = 0;
 
     if (dc_tok) { // dc
-        const int dc_sign_ctx = get_dc_sign_ctx(t_dim, a, l);
+        const int dc_sign_ctx = get_dc_sign_ctx(tx, a, l);
         uint16_t *const dc_sign_cdf =
             ts->cdf.coef.dc_sign[chroma][dc_sign_ctx];
         const int sign = dav1d_msac_decode_bool_adapt(&ts->msac, dc_sign_cdf);
@@ -328,7 +620,7 @@ static int decode_coefs(Dav1dTileContext *const t,
     }
 
     // context
-    *res_ctx = imin(cul_level, 63) | dc_sign;
+    *res_ctx = umin(cul_level, 63) | dc_sign;
 
     return eob;
 }
--- a/src/tables.c
+++ b/src/tables.c
@@ -272,119 +272,34 @@
     [BS_4x4  ]   = 0,
 };
 
-const uint8_t dav1d_nz_map_ctx_offset[N_RECT_TX_SIZES][5][5] = {
-    [TX_4X4] = {
-        { 0, 1, 6, 6 },
-        { 1, 6, 6, 21 },
-        { 6, 6, 21, 21 },
-        { 6, 21, 21, 21 },
-    }, [TX_8X8] = {
-        { 0, 1, 6, 6, 21 },
-        { 1, 6, 6, 21, 21 },
-        { 6, 6, 21, 21, 21 },
-        { 6, 21, 21, 21, 21 },
-        { 21, 21, 21, 21, 21 }
-    }, [TX_16X16] = {
-        { 0, 1, 6, 6, 21 },
-        { 1, 6, 6, 21, 21 },
-        { 6, 6, 21, 21, 21 },
-        { 6, 21, 21, 21, 21 },
-        { 21, 21, 21, 21, 21 }
-    }, [TX_32X32] = {
-        { 0, 1, 6, 6, 21 },
-        { 1, 6, 6, 21, 21 },
-        { 6, 6, 21, 21, 21 },
-        { 6, 21, 21, 21, 21 },
-        { 21, 21, 21, 21, 21 }
-    }, [TX_64X64] = {
-        { 0, 1, 6, 6, 21 },
-        { 1, 6, 6, 21, 21 },
-        { 6, 6, 21, 21, 21 },
-        { 6, 21, 21, 21, 21 },
-        { 21, 21, 21, 21, 21 }
-    }, [RTX_4X8] = {
-        { 0, 11, 11, 11 },
-        { 11, 11, 11, 11 },
-        { 6, 6, 21, 21 },
-        { 6, 21, 21, 21 },
-        { 21, 21, 21, 21 }
-    }, [RTX_8X4] = {
-        { 0, 16, 6, 6, 21 },
-        { 16, 16, 6, 21, 21 },
+const uint8_t dav1d_lo_ctx_offsets[3][5][5] = { /* [0: w == h, 1: w > h, 2: w < h][min(y, 4)][min(x, 4)] */
+    { /* w == h */
+        {  0,  1,  6,  6, 21 },
+        {  1,  6,  6, 21, 21 },
+        {  6,  6, 21, 21, 21 },
+        {  6, 21, 21, 21, 21 },
+        { 21, 21, 21, 21, 21 },
+    }, { /* w > h */
+        {  0, 16,  6,  6, 21 },
+        { 16, 16,  6, 21, 21 },
         { 16, 16, 21, 21, 21 },
         { 16, 16, 21, 21, 21 },
-    }, [RTX_8X16] = {
-        { 0, 11, 11, 11, 11 },
-        { 11, 11, 11, 11, 11 },
-        { 6, 6, 21, 21, 21 },
-        { 6, 21, 21, 21, 21 },
-        { 21, 21, 21, 21, 21 }
-    }, [RTX_16X8] = {
-        { 0, 16, 6, 6, 21 },
-        { 16, 16, 6, 21, 21 },
         { 16, 16, 21, 21, 21 },
-        { 16, 16, 21, 21, 21 },
-        { 16, 16, 21, 21, 21 }
-    }, [RTX_16X32] = {
-        { 0, 11, 11, 11, 11 },
+    }, { /* w < h */
+        {  0, 11, 11, 11, 11 },
         { 11, 11, 11, 11, 11 },
-        { 6, 6, 21, 21, 21 },
-        { 6, 21, 21, 21, 21 },
-        { 21, 21, 21, 21, 21 }
-    }, [RTX_32X16] = {
-        { 0, 16, 6, 6, 21 },
-        { 16, 16, 6, 21, 21 },
-        { 16, 16, 21, 21, 21 },
-        { 16, 16, 21, 21, 21 },
-        { 16, 16, 21, 21, 21 }
-    }, [RTX_32X64] = {
-        { 0, 11, 11, 11, 11 },
-        { 11, 11, 11, 11, 11 },
-        { 6, 6, 21, 21, 21 },
-        { 6, 21, 21, 21, 21 },
-        { 21, 21, 21, 21, 21 }
-    }, [RTX_64X32] = {
-        { 0, 16, 6, 6, 21 },
-        { 16, 16, 6, 21, 21 },
-        { 16, 16, 21, 21, 21 },
-        { 16, 16, 21, 21, 21 },
-        { 16, 16, 21, 21, 21 }
-    }, [RTX_4X16] = {
-        { 0, 11, 11, 11 },
-        { 11, 11, 11, 11 },
-        { 6, 6, 21, 21 },
-        { 6, 21, 21, 21 },
-        { 21, 21, 21, 21 }
-    }, [RTX_16X4] = {
-        { 0, 16, 6, 6, 21 },
-        { 16, 16, 6, 21, 21 },
-        { 16, 16, 21, 21, 21 },
-        { 16, 16, 21, 21, 21 },
-    }, [RTX_8X32] = {
-        { 0, 11, 11, 11, 11 },
-        { 11, 11, 11, 11, 11 },
-        { 6, 6, 21, 21, 21 },
-        { 6, 21, 21, 21, 21 },
-        { 21, 21, 21, 21, 21 }
-    }, [RTX_32X8] = {
-        { 0, 16, 6, 6, 21 },
-        { 16, 16, 6, 21, 21 },
-        { 16, 16, 21, 21, 21 },
-        { 16, 16, 21, 21, 21 },
-        { 16, 16, 21, 21, 21 }
-    }, [RTX_16X64] = {
-        { 0, 11, 11, 11, 11 },
-        { 11, 11, 11, 11, 11 },
-        { 6, 6, 21, 21, 21 },
-        { 6, 21, 21, 21, 21 },
-        { 21, 21, 21, 21, 21 }
-    }, [RTX_64X16] = {
-        { 0, 16, 6, 6, 21 },
-        { 16, 16, 6, 21, 21 },
-        { 16, 16, 21, 21, 21 },
-        { 16, 16, 21, 21, 21 },
-        { 16, 16, 21, 21, 21 }
-    }
+        {  6,  6, 21, 21, 21 },
+        {  6, 21, 21, 21, 21 },
+        { 21, 21, 21, 21, 21 },
+    },
+};
+
+const uint8_t dav1d_skip_ctx[5][5] = { /* indexed by the a/l skip levels, each clamped to 4; symmetric */
+    { 1, 2, 2, 2, 3 },
+    { 2, 4, 4, 4, 5 },
+    { 2, 4, 4, 4, 5 },
+    { 2, 4, 4, 4, 5 },
+    { 3, 5, 5, 5, 6 },
 };
 
 const uint8_t /* enum TxClass */ dav1d_tx_type_class[N_TX_TYPES_PLUS_LL] = {
--- a/src/tables.h
+++ b/src/tables.h
@@ -57,7 +57,8 @@
 
 extern const uint8_t dav1d_filter_mode_to_y_mode[5];
 extern const uint8_t dav1d_ymode_size_context[N_BS_SIZES];
-extern const uint8_t dav1d_nz_map_ctx_offset[N_RECT_TX_SIZES][5][5];
+extern const uint8_t dav1d_lo_ctx_offsets[3][5][5]; /* [0: w == h, 1: w > h, 2: w < h][y][x] */
+extern const uint8_t dav1d_skip_ctx[5][5]; /* symmetric skip ctx lookup */
 extern const uint8_t /* enum TxClass */
                      dav1d_tx_type_class[N_TX_TYPES_PLUS_LL];
 extern const uint8_t /* enum Filter2d */