shithub: dav1d

Download patch

ref: e29fd5c0016fec27c88a36ac6f6eaaf416d91330
parent: a819653e1b71ea69c13faaa64c5bb89534ce2772
author: Henrik Gramner <[email protected]>
date: Tue Aug 6 11:17:31 EDT 2019

Add msac optimizations

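 * Eliminate the trailing zero after the CDF probabilities. We can
   reuse the count value as a terminator instead. This reduces the
   size of the CDF context by around 8%. As a consequence, callers
   of the symbol decoding functions now pass the number of symbols
   minus one as n_symbols, so that cdf[n_symbols] is the count.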

 * Align the CDF arrays.

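 * Various other minor optimizations, e.g. loading BUF_POS and
   BUF_END with a single ldp in the ARM64 refill path, and using
   the dav1d_partition_type_count table in both cdf.c and decode.c
   instead of local partition count logic.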

--- a/src/arm/64/msac.S
+++ b/src/arm/64/msac.S
@@ -148,7 +148,7 @@
         add             x8,  x0,  #RNG
         ld1_n           v0,  v1,  x1,  \sz, \n                    // cdf
         ld1r            {v4\sz},  [x8]                            // rng
-        movrel          x9,  coeffs, 32
+        movrel          x9,  coeffs, 30
         sub             x9,  x9,  x2, lsl #1
         ushr_n          v2,  v3,  v0,  v1,  #6, \sz, \n           // cdf >> EC_PROB_SHIFT
         str             h4,  [sp, #14]                            // store original u = s->rng
@@ -183,16 +183,24 @@
         // update_cdf
         ldrh            w3,  [x1, x2, lsl #1]                     // count = cdf[n_symbols]
         movi            v5\szb, #0xff
-        cmp             x2,  #4                                   // set C if n_symbols >= 4 (n_symbols > 3)
-        mov             w14, #4
-        lsr             w4,  w3,  #4                              // count >> 4
+.if \n == 16
+        mov             w4,  #-5
+.else
+        mvn             w14, w2
+        mov             w4,  #-4
+        cmn             w14, #3                                   // set C if n_symbols <= 2
+.endif
         urhadd_n        v4,  v5,  v5,  v5,  v2,  v3,  \sz, \n     // i >= val ? -1 : 32768
-        adc             w4,  w4,  w14                             // (count >> 4) + (n_symbols > 3) + 4
-        neg             w4,  w4                                   // -rate
+.if \n == 16
+        sub             w4,  w4,  w3, lsr #4                      // -((count >> 4) + 5)
+.else
+        lsr             w14, w3,  #4                              // count >> 4
+        sbc             w4,  w4,  w14                             // -((count >> 4) + (n_symbols > 2) + 4)
+.endif
         sub_n           v4,  v5,  v4,  v5,  v0,  v1,  \sz, \n     // (32768 - cdf[i]) or (-1 - cdf[i])
         dup             v6.8h,    w4                              // -rate
 
-        sub             w3,  w3,  w3, lsr #5                      // count - (count >= 32)
+        sub             w3,  w3,  w3, lsr #5                      // count - (count == 32)
         sub_n           v0,  v1,  v0,  v1,  v2,  v3,  \sz, \n     // cdf + (i >= val ? 1 : 0)
         sshl_n          v4,  v5,  v4,  v5,  v6,  v6,  \sz, \n     // ({32768,-1} - cdf[i]) >> rate
         add             w3,  w3,  #1                              // count + (count < 32)
@@ -224,8 +232,7 @@
         b.ge            9f
 
         // refill
-        ldr             x3,  [x0, #BUF_POS]
-        ldr             x4,  [x0, #BUF_END]
+        ldp             x3,  x4,  [x0]         // BUF_POS, BUF_END
         add             x5,  x3,  #8
         cmp             x5,  x4
         b.gt            2f
--- a/src/cdf.c
+++ b/src/cdf.c
@@ -34,6 +34,7 @@
 #include "common/intops.h"
 
 #include "src/cdf.h"
+#include "src/tables.h"
 
 #define AOM_ICDF(x) (32768-(x))
 
@@ -752,12 +753,11 @@
     }
 };
 
-
-static const uint16_t default_mv_joint_cdf[N_MV_JOINTS + 1] = {
+static const uint16_t ALIGN(default_mv_joint_cdf[N_MV_JOINTS], 8) = {
     AOM_CDF4(4096, 11264, 19328)
 };
 
-static const uint16_t default_kf_y_mode_cdf[5][5][N_INTRA_PRED_MODES + 1 + 2] = {
+static const uint16_t ALIGN(default_kf_y_mode_cdf[5][5][N_INTRA_PRED_MODES + 3], 32) = {
     {
         { AOM_CDF13(15588, 17027, 19338, 20218, 20682, 21110, 21825, 23244,
                     24189, 28165, 29093, 30466) },
@@ -3927,25 +3927,18 @@
                              CdfContext *const dst,
                              const CdfContext *const src)
 {
-    int i, j, k, l;
-
 #define update_cdf_1d(n1d, name) \
     do { \
-        memcpy(dst->name, src->name, sizeof(*dst->name) * n1d); \
-        assert(!dst->name[n1d - 1]); \
+        memcpy(dst->name, src->name, sizeof(dst->name)); \
         dst->name[n1d] = 0; \
     } while (0)
 
 #define update_cdf_2d(n1d, n2d, name) \
-    for (j = 0; j < (n1d); j++) update_cdf_1d(n2d, name[j])
+    for (int j = 0; j < (n1d); j++) update_cdf_1d(n2d, name[j])
 #define update_cdf_3d(n1d, n2d, n3d, name) \
-    for (k = 0; k < (n1d); k++) update_cdf_2d(n2d, n3d, name[k])
+    for (int k = 0; k < (n1d); k++) update_cdf_2d(n2d, n3d, name[k])
 #define update_cdf_4d(n1d, n2d, n3d, n4d, name) \
-    for (l = 0; l < (n1d); l++) update_cdf_3d(n2d, n3d, n4d, name[l])
-#define update_cdf_6d(n1d, n2d, n3d, n4d, n5d, n6d, name) \
-    for (n = 0; n < (n1d); n++) \
-        for (m = 0; m < (n2d); m++) \
-            update_cdf_4d(n3d, n4d, n5d, n6d, name[n][m])
+    for (int l = 0; l < (n1d); l++) update_cdf_3d(n2d, n3d, n4d, name[l])
 
 #define update_bit_0d(name) \
     do { \
@@ -3954,65 +3947,57 @@
     } while (0)
 
 #define update_bit_1d(n1d, name) \
-    for (i = 0; i < (n1d); i++) update_bit_0d(name[i])
+    for (int i = 0; i < (n1d); i++) update_bit_0d(name[i])
 #define update_bit_2d(n1d, n2d, name) \
-    for (j = 0; j < (n1d); j++) update_bit_1d(n2d, name[j])
+    for (int j = 0; j < (n1d); j++) update_bit_1d(n2d, name[j])
 #define update_bit_3d(n1d, n2d, n3d, name) \
-    for (k = 0; k < (n1d); k++) update_bit_2d(n2d, n3d, name[k])
+    for (int k = 0; k < (n1d); k++) update_bit_2d(n2d, n3d, name[k])
 
     update_bit_1d(N_BS_SIZES, m.use_filter_intra);
-    update_cdf_1d(5, m.filter_intra);
-    update_cdf_3d(2, N_INTRA_PRED_MODES, N_UV_INTRA_PRED_MODES - !k, m.uv_mode);
-    update_cdf_2d(8, 7, m.angle_delta);
-    update_cdf_3d(N_TX_SIZES - 1, 3, imin(k + 2, 3), m.txsz);
-    update_cdf_3d(2, N_INTRA_PRED_MODES, 7, m.txtp_intra1);
-    update_cdf_3d(3, N_INTRA_PRED_MODES, 5, m.txtp_intra2);
+    update_cdf_1d(4, m.filter_intra);
+    update_cdf_3d(2, N_INTRA_PRED_MODES, N_UV_INTRA_PRED_MODES - 1 - !k, m.uv_mode);
+    update_cdf_2d(8, 6, m.angle_delta);
+    update_cdf_3d(N_TX_SIZES - 1, 3, imin(k + 1, 2), m.txsz);
+    update_cdf_3d(2, N_INTRA_PRED_MODES, 6, m.txtp_intra1);
+    update_cdf_3d(3, N_INTRA_PRED_MODES, 4, m.txtp_intra2);
     update_bit_1d(3, m.skip);
-    static const uint8_t n_partitions[N_BL_LEVELS] = {
-        [BL_128X128] = N_PARTITIONS - 2,
-        [BL_64X64]   = N_PARTITIONS,
-        [BL_32X32]   = N_PARTITIONS,
-        [BL_16X16]   = N_PARTITIONS,
-        [BL_8X8]     = N_SUB8X8_PARTITIONS,
-    };
-    update_cdf_3d(N_BL_LEVELS, 4, n_partitions[k], m.partition);
+    update_cdf_3d(N_BL_LEVELS, 4, dav1d_partition_type_count[k], m.partition);
     update_bit_2d(N_TX_SIZES, 13, coef.skip);
-    update_cdf_3d(2, 2, 5, coef.eob_bin_16);
-    update_cdf_3d(2, 2, 6, coef.eob_bin_32);
-    update_cdf_3d(2, 2, 7, coef.eob_bin_64);
-    update_cdf_3d(2, 2, 8, coef.eob_bin_128);
-    update_cdf_3d(2, 2, 9, coef.eob_bin_256);
-    update_cdf_2d(2, 10, coef.eob_bin_512);
-    update_cdf_2d(2, 11, coef.eob_bin_1024);
+    update_cdf_3d(2, 2, 4, coef.eob_bin_16);
+    update_cdf_3d(2, 2, 5, coef.eob_bin_32);
+    update_cdf_3d(2, 2, 6, coef.eob_bin_64);
+    update_cdf_3d(2, 2, 7, coef.eob_bin_128);
+    update_cdf_3d(2, 2, 8, coef.eob_bin_256);
+    update_cdf_2d(2, 9, coef.eob_bin_512);
+    update_cdf_2d(2, 10, coef.eob_bin_1024);
     update_bit_3d(N_TX_SIZES, 2, 11 /*22*/, coef.eob_hi_bit);
-    update_cdf_4d(N_TX_SIZES, 2, 4, 3, coef.eob_base_tok);
-    update_cdf_4d(N_TX_SIZES, 2, 41 /*42*/, 4, coef.base_tok);
+    update_cdf_4d(N_TX_SIZES, 2, 4, 2, coef.eob_base_tok);
+    update_cdf_4d(N_TX_SIZES, 2, 41 /*42*/, 3, coef.base_tok);
     update_bit_2d(2, 3, coef.dc_sign);
-    update_cdf_4d(4, 2, 21, 4, coef.br_tok);
-    update_cdf_2d(3, DAV1D_MAX_SEGMENTS, m.seg_id);
-    update_cdf_1d(8, m.cfl_sign);
-    update_cdf_2d(6, 16, m.cfl_alpha);
+    update_cdf_4d(4, 2, 21, 3, coef.br_tok);
+    update_cdf_2d(3, DAV1D_MAX_SEGMENTS - 1, m.seg_id);
+    update_cdf_1d(7, m.cfl_sign);
+    update_cdf_2d(6, 15, m.cfl_alpha);
     update_bit_0d(m.restore_wiener);
     update_bit_0d(m.restore_sgrproj);
-    update_cdf_1d(3, m.restore_switchable);
-    update_cdf_1d(4, m.delta_q);
-    update_cdf_2d(5, 4, m.delta_lf);
+    update_cdf_1d(2, m.restore_switchable);
+    update_cdf_1d(3, m.delta_q);
+    update_cdf_2d(5, 3, m.delta_lf);
     update_bit_2d(7, 3, m.pal_y);
     update_bit_1d(2, m.pal_uv);
-    update_cdf_3d(2, 7, 7, m.pal_sz);
-    update_cdf_4d(2, 7, 5, k + 2, m.color_map);
-
+    update_cdf_3d(2, 7, 6, m.pal_sz);
+    update_cdf_4d(2, 7, 5, k + 1, m.color_map);
     update_bit_2d(7, 3, m.txpart);
-    update_cdf_2d(2, 16, m.txtp_inter1);
-    update_cdf_1d(12, m.txtp_inter2);
+    update_cdf_2d(2, 15, m.txtp_inter1);
+    update_cdf_1d(11, m.txtp_inter2);
     update_bit_1d(4, m.txtp_inter3);
 
     if (!(hdr->frame_type & 1)) {
         update_bit_0d(m.intrabc);
 
-        update_cdf_1d(N_MV_JOINTS, dmv.joint);
-        for (k = 0; k < 2; k++) {
-            update_cdf_1d(11, dmv.comp[k].classes);
+        update_cdf_1d(N_MV_JOINTS - 1, dmv.joint);
+        for (int k = 0; k < 2; k++) {
+            update_cdf_1d(10, dmv.comp[k].classes);
             update_bit_0d(dmv.comp[k].class0);
             update_bit_1d(10, dmv.comp[k].classN);
             update_bit_0d(dmv.comp[k].sign);
@@ -4021,13 +4006,13 @@
     }
 
     update_bit_1d(3, m.skip_mode);
-    update_cdf_2d(4, N_INTRA_PRED_MODES, m.y_mode);
-    update_cdf_3d(2, 8, DAV1D_N_SWITCHABLE_FILTERS, m.filter);
+    update_cdf_2d(4, N_INTRA_PRED_MODES - 1, m.y_mode);
+    update_cdf_3d(2, 8, DAV1D_N_SWITCHABLE_FILTERS - 1, m.filter);
     update_bit_1d(6, m.newmv_mode);
     update_bit_1d(2, m.globalmv_mode);
     update_bit_1d(6, m.refmv_mode);
     update_bit_1d(3, m.drl_bit);
-    update_cdf_2d(8, N_COMP_INTER_PRED_MODES, m.comp_inter_mode);
+    update_cdf_2d(8, N_COMP_INTER_PRED_MODES - 1, m.comp_inter_mode);
     update_bit_1d(4, m.intra);
     update_bit_1d(5, m.comp);
     update_bit_1d(5, m.comp_dir);
@@ -4034,7 +4019,7 @@
     update_bit_1d(6, m.jnt_comp);
     update_bit_1d(6, m.mask_comp);
     update_bit_1d(9, m.wedge_comp);
-    update_cdf_2d(9, 16, m.wedge_idx);
+    update_cdf_2d(9, 15, m.wedge_idx);
     update_bit_2d(6, 3, m.ref);
     update_bit_2d(3, 3, m.comp_fwd_ref);
     update_bit_2d(2, 3, m.comp_bwd_ref);
@@ -4042,17 +4027,17 @@
     update_bit_1d(3, m.seg_pred);
     update_bit_1d(4, m.interintra);
     update_bit_1d(7, m.interintra_wedge);
-    update_cdf_2d(4, 4, m.interintra_mode);
-    update_cdf_2d(N_BS_SIZES, 3, m.motion_mode);
+    update_cdf_2d(4, 3, m.interintra_mode);
+    update_cdf_2d(N_BS_SIZES, 2, m.motion_mode);
     update_bit_1d(N_BS_SIZES, m.obmc);
 
-    update_cdf_1d(N_MV_JOINTS, mv.joint);
-    for (k = 0; k < 2; k++) {
-        update_cdf_1d(11, mv.comp[k].classes);
+    update_cdf_1d(N_MV_JOINTS - 1, mv.joint);
+    for (int k = 0; k < 2; k++) {
+        update_cdf_1d(10, mv.comp[k].classes);
         update_bit_0d(mv.comp[k].class0);
         update_bit_1d(10, mv.comp[k].classN);
-        update_cdf_2d(2, 4, mv.comp[k].class0_fp);
-        update_cdf_1d(4, mv.comp[k].classN_fp);
+        update_cdf_2d(2, 3, mv.comp[k].class0_fp);
+        update_cdf_1d(3, mv.comp[k].classN_fp);
         update_bit_0d(mv.comp[k].class0_hp);
         update_bit_0d(mv.comp[k].classN_hp);
         update_bit_0d(mv.comp[k].sign);
@@ -4062,7 +4047,7 @@
 /*
  * CDF threading wrappers.
  */
-static inline int get_qcat_idx(int q) {
+static inline int get_qcat_idx(const int q) {
     if (q <= 20) return 0;
     if (q <= 60) return 1;
     if (q <= 120) return 2;
@@ -4089,7 +4074,7 @@
 }
 
 int dav1d_cdf_thread_alloc(CdfThreadContext *const cdf,
-                            struct thread_data *const t)
+                           struct thread_data *const t)
 {
     cdf->ref = dav1d_ref_create(sizeof(CdfContext) +
                                 (t != NULL) * sizeof(atomic_uint));
--- a/src/cdf.h
+++ b/src/cdf.h
@@ -37,94 +37,94 @@
 /* Buffers padded to [8] or [16] for SIMD where needed. */
 
 typedef struct CdfModeContext {
-    uint16_t y_mode[4][N_INTRA_PRED_MODES + 1 + 2];
-    uint16_t use_filter_intra[N_BS_SIZES][2];
-    uint16_t filter_intra[5 + 1];
-    uint16_t uv_mode[2][N_INTRA_PRED_MODES][N_UV_INTRA_PRED_MODES + 1 + 1];
-    uint16_t angle_delta[8][8];
-    uint16_t filter[2][8][DAV1D_N_SWITCHABLE_FILTERS + 1];
-    uint16_t newmv_mode[6][2];
-    uint16_t globalmv_mode[2][2];
-    uint16_t refmv_mode[6][2];
-    uint16_t drl_bit[3][2];
-    uint16_t comp_inter_mode[8][N_COMP_INTER_PRED_MODES + 1];
-    uint16_t intra[4][2];
-    uint16_t comp[5][2];
-    uint16_t comp_dir[5][2];
-    uint16_t jnt_comp[6][2];
-    uint16_t mask_comp[6][2];
-    uint16_t wedge_comp[9][2];
-    uint16_t wedge_idx[9][16 + 1];
-    uint16_t interintra[7][2];
-    uint16_t interintra_mode[4][5];
-    uint16_t interintra_wedge[7][2];
-    uint16_t ref[6][3][2];
-    uint16_t comp_fwd_ref[3][3][2];
-    uint16_t comp_bwd_ref[2][3][2];
-    uint16_t comp_uni_ref[3][3][2];
-    uint16_t txsz[N_TX_SIZES - 1][3][4];
-    uint16_t txpart[7][3][2];
-    uint16_t txtp_inter1[2][16 + 1];
-    uint16_t txtp_inter2[12 + 1 + 3];
-    uint16_t txtp_inter3[4][2];
-    uint16_t txtp_intra1[2][N_INTRA_PRED_MODES][7 + 1];
-    uint16_t txtp_intra2[3][N_INTRA_PRED_MODES][5 + 1 + 2];
-    uint16_t skip[3][2];
-    uint16_t skip_mode[3][2];
-    uint16_t partition[N_BL_LEVELS][4][N_PARTITIONS + 1 + 5];
-    uint16_t seg_pred[3][2];
-    uint16_t seg_id[3][DAV1D_MAX_SEGMENTS + 1];
-    uint16_t cfl_sign[8 + 1];
-    uint16_t cfl_alpha[6][16 + 1];
-    uint16_t restore_wiener[2];
-    uint16_t restore_sgrproj[2];
-    uint16_t restore_switchable[3 + 1];
-    uint16_t delta_q[4 + 1];
-    uint16_t delta_lf[5][4 + 1];
-    uint16_t obmc[N_BS_SIZES][2];
-    uint16_t motion_mode[N_BS_SIZES][3 + 1];
-    uint16_t pal_y[7][3][2];
-    uint16_t pal_uv[2][2];
-    uint16_t pal_sz[2][7][7 + 1];
-    uint16_t color_map[2][7][5][8 + 1];
-    uint16_t intrabc[2];
+    ALIGN(uint16_t y_mode[4][N_INTRA_PRED_MODES + 3], 32);
+    ALIGN(uint16_t uv_mode[2][N_INTRA_PRED_MODES][N_UV_INTRA_PRED_MODES + 2], 32);
+    ALIGN(uint16_t wedge_idx[9][16], 32);
+    ALIGN(uint16_t partition[N_BL_LEVELS][4][N_PARTITIONS + 6], 32);
+    ALIGN(uint16_t cfl_alpha[6][16], 32);
+    ALIGN(uint16_t txtp_inter1[2][16], 32);
+    ALIGN(uint16_t txtp_inter2[12 + 4], 32);
+    ALIGN(uint16_t txtp_intra1[2][N_INTRA_PRED_MODES][7 + 1], 16);
+    ALIGN(uint16_t txtp_intra2[3][N_INTRA_PRED_MODES][5 + 3], 16);
+    ALIGN(uint16_t cfl_sign[8], 16);
+    ALIGN(uint16_t angle_delta[8][8], 16);
+    ALIGN(uint16_t filter_intra[5 + 3], 16);
+    ALIGN(uint16_t comp_inter_mode[8][N_COMP_INTER_PRED_MODES], 16);
+    ALIGN(uint16_t seg_id[3][DAV1D_MAX_SEGMENTS], 16);
+    ALIGN(uint16_t pal_sz[2][7][7 + 1], 16);
+    ALIGN(uint16_t color_map[2][7][5][8], 16);
+    ALIGN(uint16_t filter[2][8][DAV1D_N_SWITCHABLE_FILTERS + 1], 8);
+    ALIGN(uint16_t txsz[N_TX_SIZES - 1][3][4], 8);
+    ALIGN(uint16_t motion_mode[N_BS_SIZES][3 + 1], 8);
+    ALIGN(uint16_t delta_q[4], 8);
+    ALIGN(uint16_t delta_lf[5][4], 8);
+    ALIGN(uint16_t interintra_mode[4][4], 8);
+    ALIGN(uint16_t restore_switchable[3 + 1], 8);
+    ALIGN(uint16_t restore_wiener[2], 4);
+    ALIGN(uint16_t restore_sgrproj[2], 4);
+    ALIGN(uint16_t interintra[7][2], 4);
+    ALIGN(uint16_t interintra_wedge[7][2], 4);
+    ALIGN(uint16_t txtp_inter3[4][2], 4);
+    ALIGN(uint16_t use_filter_intra[N_BS_SIZES][2], 4);
+    ALIGN(uint16_t newmv_mode[6][2], 4);
+    ALIGN(uint16_t globalmv_mode[2][2], 4);
+    ALIGN(uint16_t refmv_mode[6][2], 4);
+    ALIGN(uint16_t drl_bit[3][2], 4);
+    ALIGN(uint16_t intra[4][2], 4);
+    ALIGN(uint16_t comp[5][2], 4);
+    ALIGN(uint16_t comp_dir[5][2], 4);
+    ALIGN(uint16_t jnt_comp[6][2], 4);
+    ALIGN(uint16_t mask_comp[6][2], 4);
+    ALIGN(uint16_t wedge_comp[9][2], 4);
+    ALIGN(uint16_t ref[6][3][2], 4);
+    ALIGN(uint16_t comp_fwd_ref[3][3][2], 4);
+    ALIGN(uint16_t comp_bwd_ref[2][3][2], 4);
+    ALIGN(uint16_t comp_uni_ref[3][3][2], 4);
+    ALIGN(uint16_t txpart[7][3][2], 4);
+    ALIGN(uint16_t skip[3][2], 4);
+    ALIGN(uint16_t skip_mode[3][2], 4);
+    ALIGN(uint16_t seg_pred[3][2], 4);
+    ALIGN(uint16_t obmc[N_BS_SIZES][2], 4);
+    ALIGN(uint16_t pal_y[7][3][2], 4);
+    ALIGN(uint16_t pal_uv[2][2], 4);
+    ALIGN(uint16_t intrabc[2], 4);
 } CdfModeContext;
 
 typedef struct CdfCoefContext {
-    uint16_t skip[N_TX_SIZES][13][2];
-    uint16_t eob_bin_16[2][2][6];
-    uint16_t eob_bin_32[2][2][7 + 1];
-    uint16_t eob_bin_64[2][2][8];
-    uint16_t eob_bin_128[2][2][9];
-    uint16_t eob_bin_256[2][2][10 + 6];
-    uint16_t eob_bin_512[2][11 + 5];
-    uint16_t eob_bin_1024[2][12 + 4];
-    uint16_t eob_hi_bit[N_TX_SIZES][2][11 /*22*/][2];
-    uint16_t eob_base_tok[N_TX_SIZES][2][4][4];
-    uint16_t base_tok[N_TX_SIZES][2][41][5];
-    uint16_t dc_sign[2][3][2];
-    uint16_t br_tok[4 /*5*/][2][21][5];
+    ALIGN(uint16_t eob_bin_16[2][2][5 + 3], 16);
+    ALIGN(uint16_t eob_bin_32[2][2][6 + 2], 16);
+    ALIGN(uint16_t eob_bin_64[2][2][7 + 1], 16);
+    ALIGN(uint16_t eob_bin_128[2][2][8 + 0], 16);
+    ALIGN(uint16_t eob_bin_256[2][2][9 + 7], 32);
+    ALIGN(uint16_t eob_bin_512[2][10 + 6], 32);
+    ALIGN(uint16_t eob_bin_1024[2][11 + 5], 32);
+    ALIGN(uint16_t eob_base_tok[N_TX_SIZES][2][4][4], 8);
+    ALIGN(uint16_t base_tok[N_TX_SIZES][2][41][4], 8);
+    ALIGN(uint16_t br_tok[4 /*5*/][2][21][4], 8);
+    ALIGN(uint16_t eob_hi_bit[N_TX_SIZES][2][11 /*22*/][2], 4);
+    ALIGN(uint16_t skip[N_TX_SIZES][13][2], 4);
+    ALIGN(uint16_t dc_sign[2][3][2], 4);
 } CdfCoefContext;
 
 typedef struct CdfMvComponent {
-    uint16_t classes[11 + 1 + 4];
-    uint16_t class0[2];
-    uint16_t classN[10][2];
-    uint16_t class0_fp[2][4 + 1];
-    uint16_t classN_fp[4 + 1];
-    uint16_t class0_hp[2];
-    uint16_t classN_hp[2];
-    uint16_t sign[2];
+    ALIGN(uint16_t classes[11 + 5], 32);
+    ALIGN(uint16_t class0_fp[2][4], 8);
+    ALIGN(uint16_t classN_fp[4], 8);
+    ALIGN(uint16_t class0_hp[2], 4);
+    ALIGN(uint16_t classN_hp[2], 4);
+    ALIGN(uint16_t class0[2], 4);
+    ALIGN(uint16_t classN[10][2], 4);
+    ALIGN(uint16_t sign[2], 4);
 } CdfMvComponent;
 
 typedef struct CdfMvContext {
     CdfMvComponent comp[2];
-    uint16_t joint[N_MV_JOINTS + 1];
+    ALIGN(uint16_t joint[N_MV_JOINTS], 8);
 } CdfMvContext;
 
 typedef struct CdfContext {
     CdfModeContext m;
-    uint16_t kfym[5][5][N_INTRA_PRED_MODES + 1 + 2];
+    ALIGN(uint16_t kfym[5][5][N_INTRA_PRED_MODES + 3], 32);
     CdfCoefContext coef;
     CdfMvContext mv, dmv;
 } CdfContext;
--- a/src/decode.c
+++ b/src/decode.c
@@ -81,7 +81,7 @@
     const int have_hp = f->frame_hdr->hp;
     const int sign = dav1d_msac_decode_bool_adapt(&ts->msac, mv_comp->sign);
     const int cl = dav1d_msac_decode_symbol_adapt16(&ts->msac,
-                                                    mv_comp->classes, 11);
+                                                    mv_comp->classes, 10);
     int up, fp, hp;
 
     if (!cl) {
@@ -88,7 +88,7 @@
         up = dav1d_msac_decode_bool_adapt(&ts->msac, mv_comp->class0);
         if (have_fp) {
             fp = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                 mv_comp->class0_fp[up], 4);
+                                                 mv_comp->class0_fp[up], 3);
             hp = have_hp ? dav1d_msac_decode_bool_adapt(&ts->msac,
                                                         mv_comp->class0_hp) : 1;
         } else {
@@ -102,7 +102,7 @@
                                                mv_comp->classN[n]) << n;
         if (have_fp) {
             fp = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                 mv_comp->classN_fp, 4);
+                                                 mv_comp->classN_fp, 3);
             hp = have_hp ? dav1d_msac_decode_bool_adapt(&ts->msac,
                                                         mv_comp->classN_hp) : 1;
         } else {
@@ -120,7 +120,7 @@
                              CdfMvContext *const mv_cdf, const int have_fp)
 {
     switch (dav1d_msac_decode_symbol_adapt4(&t->ts->msac, t->ts->cdf.mv.joint,
-                                            N_MV_JOINTS))
+                                            N_MV_JOINTS - 1))
     {
     case MV_JOINT_HV:
         ref_mv->y += read_mv_component_diff(t, &mv_cdf->comp[0], have_fp);
@@ -380,7 +380,7 @@
     Dav1dTileState *const ts = t->ts;
     const Dav1dFrameContext *const f = t->f;
     const int pal_sz = b->pal_sz[pl] = dav1d_msac_decode_symbol_adapt8(&ts->msac,
-                                           ts->cdf.m.pal_sz[pl][sz_ctx], 7) + 2;
+                                           ts->cdf.m.pal_sz[pl][sz_ctx], 6) + 2;
     uint16_t cache[16], used_cache[8];
     int l_cache = pl ? t->pal_sz_uv[1][by4] : t->l.pal_sz[by4];
     int n_cache = 0;
@@ -586,7 +586,7 @@
     Dav1dTileState *const ts = t->ts;
     const ptrdiff_t stride = bw4 * 4;
     pal_idx[0] = dav1d_msac_decode_uniform(&ts->msac, b->pal_sz[pl]);
-    uint16_t (*const color_map_cdf)[8 + 1] =
+    uint16_t (*const color_map_cdf)[8] =
         ts->cdf.m.color_map[pl][b->pal_sz[pl] - 2];
     uint8_t (*const order)[8] = t->scratch.pal_order;
     uint8_t *const ctx = t->scratch.pal_ctx;
@@ -597,7 +597,7 @@
         order_palette(pal_idx, stride, i, first, last, order, ctx);
         for (int j = first, m = 0; j >= last; j--, m++) {
             const int color_idx = dav1d_msac_decode_symbol_adapt8(&ts->msac,
-                                      color_map_cdf[ctx[m]], b->pal_sz[pl]);
+                                      color_map_cdf[ctx[m]], b->pal_sz[pl] - 1);
             pal_idx[(i - j) * stride + j] = order[m][color_idx];
         }
     }
@@ -813,7 +813,7 @@
                                         &seg_ctx, f->cur_segmap, f->b4_stride);
                 const unsigned diff = dav1d_msac_decode_symbol_adapt8(&ts->msac,
                                           ts->cdf.m.seg_id[seg_ctx],
-                                          DAV1D_MAX_SEGMENTS);
+                                          DAV1D_MAX_SEGMENTS - 1);
                 const unsigned last_active_seg_id =
                     f->frame_hdr->segmentation.seg_data.last_active_segid;
                 b->seg_id = neg_deinterleave(diff, pred_seg_id,
@@ -885,7 +885,7 @@
             } else {
                 const unsigned diff = dav1d_msac_decode_symbol_adapt8(&ts->msac,
                                           ts->cdf.m.seg_id[seg_ctx],
-                                          DAV1D_MAX_SEGMENTS);
+                                          DAV1D_MAX_SEGMENTS - 1);
                 const unsigned last_active_seg_id =
                     f->frame_hdr->segmentation.seg_data.last_active_segid;
                 b->seg_id = neg_deinterleave(diff, pred_seg_id,
@@ -933,7 +933,7 @@
 
         if (have_delta_q) {
             int delta_q = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                          ts->cdf.m.delta_q, 4);
+                                                          ts->cdf.m.delta_q, 3);
             if (delta_q == 3) {
                 const int n_bits = 1 + dav1d_msac_decode_bools(&ts->msac, 3);
                 delta_q = dav1d_msac_decode_bools(&ts->msac, n_bits) +
@@ -954,7 +954,7 @@
 
                 for (int i = 0; i < n_lfs; i++) {
                     int delta_lf = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                        ts->cdf.m.delta_lf[i + f->frame_hdr->delta.lf.multi], 4);
+                        ts->cdf.m.delta_lf[i + f->frame_hdr->delta.lf.multi], 3);
                     if (delta_lf == 3) {
                         const int n_bits = 1 + dav1d_msac_decode_bools(&ts->msac, 3);
                         delta_lf = dav1d_msac_decode_bools(&ts->msac, n_bits) +
@@ -1019,7 +1019,7 @@
             ts->cdf.kfym[dav1d_intra_mode_context[t->a->mode[bx4]]]
                         [dav1d_intra_mode_context[t->l.mode[by4]]];
         b->y_mode = dav1d_msac_decode_symbol_adapt16(&ts->msac, ymode_cdf,
-                                                     N_INTRA_PRED_MODES);
+                                                     N_INTRA_PRED_MODES - 1);
         if (DEBUG_BLOCK_INFO)
             printf("Post-ymode[%d]: r=%d\n", b->y_mode, ts->msac.rng);
 
@@ -1028,7 +1028,7 @@
             b->y_mode <= VERT_LEFT_PRED)
         {
             uint16_t *const acdf = ts->cdf.m.angle_delta[b->y_mode - VERT_PRED];
-            const int angle = dav1d_msac_decode_symbol_adapt8(&ts->msac, acdf, 7);
+            const int angle = dav1d_msac_decode_symbol_adapt8(&ts->msac, acdf, 6);
             b->y_angle = angle - 3;
         } else {
             b->y_angle = 0;
@@ -1039,7 +1039,7 @@
                 cbw4 == 1 && cbh4 == 1 : !!(cfl_allowed_mask & (1 << bs));
             uint16_t *const uvmode_cdf = ts->cdf.m.uv_mode[cfl_allowed][b->y_mode];
             b->uv_mode = dav1d_msac_decode_symbol_adapt16(&ts->msac, uvmode_cdf,
-                             N_UV_INTRA_PRED_MODES - !cfl_allowed);
+                             N_UV_INTRA_PRED_MODES - 1 - !cfl_allowed);
             if (DEBUG_BLOCK_INFO)
                 printf("Post-uvmode[%d]: r=%d\n", b->uv_mode, ts->msac.rng);
 
@@ -1046,13 +1046,13 @@
             if (b->uv_mode == CFL_PRED) {
 #define SIGN(a) (!!(a) + ((a) > 0))
                 const int sign = dav1d_msac_decode_symbol_adapt8(&ts->msac,
-                                     ts->cdf.m.cfl_sign, 8) + 1;
+                                     ts->cdf.m.cfl_sign, 7) + 1;
                 const int sign_u = sign * 0x56 >> 8, sign_v = sign - sign_u * 3;
                 assert(sign_u == sign / 3);
                 if (sign_u) {
                     const int ctx = (sign_u == 2) * 3 + sign_v;
                     b->cfl_alpha[0] = dav1d_msac_decode_symbol_adapt16(&ts->msac,
-                                          ts->cdf.m.cfl_alpha[ctx], 16) + 1;
+                                          ts->cdf.m.cfl_alpha[ctx], 15) + 1;
                     if (sign_u == 1) b->cfl_alpha[0] = -b->cfl_alpha[0];
                 } else {
                     b->cfl_alpha[0] = 0;
@@ -1060,7 +1060,7 @@
                 if (sign_v) {
                     const int ctx = (sign_v == 2) * 3 + sign_u;
                     b->cfl_alpha[1] = dav1d_msac_decode_symbol_adapt16(&ts->msac,
-                                          ts->cdf.m.cfl_alpha[ctx], 16) + 1;
+                                          ts->cdf.m.cfl_alpha[ctx], 15) + 1;
                     if (sign_v == 1) b->cfl_alpha[1] = -b->cfl_alpha[1];
                 } else {
                     b->cfl_alpha[1] = 0;
@@ -1073,7 +1073,7 @@
                        b->uv_mode <= VERT_LEFT_PRED)
             {
                 uint16_t *const acdf = ts->cdf.m.angle_delta[b->uv_mode - VERT_PRED];
-                const int angle = dav1d_msac_decode_symbol_adapt8(&ts->msac, acdf, 7);
+                const int angle = dav1d_msac_decode_symbol_adapt8(&ts->msac, acdf, 6);
                 b->uv_angle = angle - 3;
             } else {
                 b->uv_angle = 0;
@@ -1114,7 +1114,7 @@
             if (is_filter) {
                 b->y_mode = FILTER_PRED;
                 b->y_angle = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                 ts->cdf.m.filter_intra, 5);
+                                 ts->cdf.m.filter_intra, 4);
             }
             if (DEBUG_BLOCK_INFO)
                 printf("Post-filterintramode[%d/%d]: r=%d\n",
@@ -1157,7 +1157,7 @@
                 const int tctx = get_tx_ctx(t->a, &t->l, t_dim, by4, bx4);
                 uint16_t *const tx_cdf = ts->cdf.m.txsz[t_dim->max - 1][tctx];
                 int depth = dav1d_msac_decode_symbol_adapt4(&ts->msac, tx_cdf,
-                                imin(t_dim->max + 1, 3));
+                                imin(t_dim->max, 2));
 
                 while (depth--) {
                     b->tx = t_dim->sub;
@@ -1479,7 +1479,7 @@
 
             b->inter_mode = dav1d_msac_decode_symbol_adapt8(&ts->msac,
                                 ts->cdf.m.comp_inter_mode[ctx],
-                                N_COMP_INTER_PRED_MODES);
+                                N_COMP_INTER_PRED_MODES - 1);
             if (DEBUG_BLOCK_INFO)
                 printf("Post-compintermode[%d,ctx=%d,n_mvs=%d]: r=%d\n",
                        b->inter_mode, ctx, n_mvs, ts->msac.rng);
@@ -1587,7 +1587,7 @@
                                        ts->cdf.m.wedge_comp[ctx]);
                     if (b->comp_type == COMP_INTER_WEDGE)
                         b->wedge_idx = dav1d_msac_decode_symbol_adapt16(&ts->msac,
-                                           ts->cdf.m.wedge_idx[ctx], 16);
+                                           ts->cdf.m.wedge_idx[ctx], 15);
                 } else {
                     b->comp_type = COMP_INTER_SEG;
                 }
@@ -1742,7 +1742,7 @@
             {
                 b->interintra_mode = dav1d_msac_decode_symbol_adapt4(&ts->msac,
                                          ts->cdf.m.interintra_mode[ii_sz_grp],
-                                         N_INTER_INTRA_PRED_MODES);
+                                         N_INTER_INTRA_PRED_MODES - 1);
                 const int wedge_ctx = dav1d_wedge_ctx_lut[bs];
                 b->interintra_type = INTER_INTRA_BLEND +
                                      dav1d_msac_decode_bool_adapt(&ts->msac,
@@ -1749,7 +1749,7 @@
                                          ts->cdf.m.interintra_wedge[wedge_ctx]);
                 if (b->interintra_type == INTER_INTRA_WEDGE)
                     b->wedge_idx = dav1d_msac_decode_symbol_adapt16(&ts->msac,
-                                       ts->cdf.m.wedge_idx[wedge_ctx], 16);
+                                       ts->cdf.m.wedge_idx[wedge_ctx], 15);
             } else {
                 b->interintra_type = INTER_INTRA_NONE;
             }
@@ -1782,7 +1782,7 @@
 
                 b->motion_mode = allow_warp ?
                     dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                        ts->cdf.m.motion_mode[bs], 3) :
+                        ts->cdf.m.motion_mode[bs], 2) :
                     dav1d_msac_decode_bool_adapt(&ts->msac, ts->cdf.m.obmc[bs]);
                 if (b->motion_mode == MM_WARP) {
                     has_subpel_filter = 0;
@@ -1822,7 +1822,7 @@
                                                 by4, bx4);
                 filter[0] = dav1d_msac_decode_symbol_adapt4(&ts->msac,
                                ts->cdf.m.filter[0][ctx1],
-                               DAV1D_N_SWITCHABLE_FILTERS);
+                               DAV1D_N_SWITCHABLE_FILTERS - 1);
                 if (f->seq_hdr->dual_filter) {
                     const int ctx2 = get_filter_ctx(t->a, &t->l, comp, 1,
                                                     b->ref[0], by4, bx4);
@@ -1831,7 +1831,7 @@
                                filter[0], ctx1, ts->msac.rng);
                     filter[1] = dav1d_msac_decode_symbol_adapt4(&ts->msac,
                                     ts->cdf.m.filter[1][ctx2],
-                                    DAV1D_N_SWITCHABLE_FILTERS);
+                                    DAV1D_N_SWITCHABLE_FILTERS - 1);
                     if (DEBUG_BLOCK_INFO)
                         printf("Post-subpel_filter2[%d,ctx=%d]: r=%d\n",
                                filter[1], ctx2, ts->msac.rng);
@@ -2022,9 +2022,8 @@
             const Av1Block *const b = &f->frame_thread.b[t->by * f->b4_stride + t->bx];
             bp = b->bl == bl ? b->bp : PARTITION_SPLIT;
         } else {
-            const unsigned n_part = bl == BL_8X8 ? N_SUB8X8_PARTITIONS :
-                bl == BL_128X128 ? N_PARTITIONS - 2 : N_PARTITIONS;
-            bp = dav1d_msac_decode_symbol_adapt16(&t->ts->msac, pc, n_part);
+            bp = dav1d_msac_decode_symbol_adapt16(&t->ts->msac, pc,
+                                                  dav1d_partition_type_count[bl]);
             if (f->cur.p.layout == DAV1D_PIXEL_LAYOUT_I422 &&
                 (bp == PARTITION_V || bp == PARTITION_V4 ||
                  bp == PARTITION_T_LEFT_SPLIT || bp == PARTITION_T_RIGHT_SPLIT))
@@ -2380,7 +2379,7 @@
 
     if (frame_type == DAV1D_RESTORATION_SWITCHABLE) {
         const int filter = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                               ts->cdf.m.restore_switchable, 3);
+                               ts->cdf.m.restore_switchable, 2);
         lr->type = filter ? filter == 2 ? DAV1D_RESTORATION_SGRPROJ :
                                           DAV1D_RESTORATION_WIENER :
                                           DAV1D_RESTORATION_NONE;
@@ -2636,9 +2635,13 @@
                 goto error;
             }
         }
+        Dav1dTileState *ts_new = dav1d_alloc_aligned(sizeof(*f->ts) * n_ts, 32);
+        if (!ts_new) goto error;
         if (n_ts > f->n_ts) {
-            Dav1dTileState *ts_new = realloc(f->ts, sizeof(*f->ts) * n_ts);
-            if (!ts_new) goto error;
+            if (f->ts) {
+                memcpy(ts_new, f->ts, sizeof(*f->ts) * f->n_ts);
+                dav1d_free_aligned(f->ts);
+            }
             f->ts = ts_new;
             for (int n = f->n_ts; n < n_ts; f->n_ts = ++n) {
                 Dav1dTileState *const ts = &f->ts[n];
@@ -2654,9 +2657,9 @@
                 pthread_cond_destroy(&ts->tile_thread.cond);
                 pthread_mutex_destroy(&ts->tile_thread.lock);
             }
+            memcpy(ts_new, f->ts, sizeof(*f->ts) * n_ts);
+            dav1d_free_aligned(f->ts);
             f->n_ts = n_ts;
-            Dav1dTileState *ts_new = realloc(f->ts, sizeof(*f->ts) * n_ts);
-            if (!ts_new) goto error;
             f->ts = ts_new;
         }
     }
--- a/src/internal.h
+++ b/src/internal.h
@@ -241,13 +241,13 @@
 };
 
 struct Dav1dTileState {
+    CdfContext cdf;
+    MsacContext msac;
+
     struct {
         int col_start, col_end, row_start, row_end; // in 4px units
         int col, row; // in tile units
     } tiling;
-
-    CdfContext cdf;
-    MsacContext msac;
 
     atomic_int progress; // in sby units, TILE_ERROR after a decoding error
     struct {
--- a/src/lib.c
+++ b/src/lib.c
@@ -502,7 +502,7 @@
             pthread_cond_destroy(&ts->tile_thread.cond);
             pthread_mutex_destroy(&ts->tile_thread.lock);
         }
-        free(f->ts);
+        dav1d_free_aligned(f->ts);
         dav1d_free_aligned(f->tc);
         dav1d_free_aligned(f->ipred_edge[0]);
         free(f->a);
--- a/src/msac.c
+++ b/src/msac.c
@@ -116,42 +116,39 @@
 
 /* Decodes a symbol given an inverse cumulative distribution function (CDF)
  * table in Q15. */
-static unsigned decode_symbol(MsacContext *const s, const uint16_t *const cdf,
-                              const size_t n_symbols)
+unsigned dav1d_msac_decode_symbol_adapt_c(MsacContext *const s,
+                                          uint16_t *const cdf,
+                                          const size_t n_symbols)
 {
-    const unsigned c = s->dif >> (EC_WIN_SIZE - 16);
-    unsigned u, v = s->rng, r = s->rng >> 8, ret = 0;
+    const unsigned c = s->dif >> (EC_WIN_SIZE - 16), r = s->rng >> 8;
+    unsigned u, v = s->rng, val = -1;
 
-    assert(!cdf[n_symbols - 1]);
+    assert(n_symbols <= 15);
+    assert(cdf[n_symbols] <= 32);
 
     do {
+        val++;
         u = v;
-        v = r * (cdf[ret++] >> EC_PROB_SHIFT);
+        v = r * (cdf[val] >> EC_PROB_SHIFT);
         v >>= 7 - EC_PROB_SHIFT;
-        v += EC_MIN_PROB * (int) (n_symbols - ret);
+        v += EC_MIN_PROB * ((unsigned)n_symbols - val);
     } while (c < v);
 
     assert(u <= s->rng);
 
     ctx_norm(s, s->dif - ((ec_win)v << (EC_WIN_SIZE - 16)), u - v);
-    return ret - 1;
-}
 
-unsigned dav1d_msac_decode_symbol_adapt_c(MsacContext *const s,
-                                          uint16_t *const cdf,
-                                          const size_t n_symbols)
-{
-    const unsigned val = decode_symbol(s, cdf, n_symbols);
     if (s->allow_update_cdf) {
         const unsigned count = cdf[n_symbols];
-        const int rate = ((count >> 4) | 4) + (n_symbols > 3);
+        const unsigned rate = 4 + (count >> 4) + (n_symbols > 2);
         unsigned i;
         for (i = 0; i < val; i++)
             cdf[i] += (32768 - cdf[i]) >> rate;
-        for (; i < n_symbols - 1; i++)
+        for (; i < n_symbols; i++)
             cdf[i] -= cdf[i] >> rate;
         cdf[n_symbols] = count + (count < 32);
     }
+
     return val;
 }
 
@@ -163,7 +160,7 @@
     if (s->allow_update_cdf) {
         // update_cdf() specialized for boolean CDFs
         const unsigned count = cdf[1];
-        const int rate = (count >> 4) | 4;
+        const int rate = 4 + (count >> 4);
         if (bit)
             cdf[0] += (32768 - cdf[0]) >> rate;
         else
--- a/src/msac.h
+++ b/src/msac.h
@@ -60,7 +60,7 @@
 unsigned dav1d_msac_decode_bool_c(MsacContext *s, unsigned f);
 int dav1d_msac_decode_subexp(MsacContext *s, int ref, int n, unsigned k);
 
-/* Supported n_symbols ranges: adapt4: 1-5, adapt8: 1-8, adapt16: 4-16 */
+/* Supported n_symbols ranges: adapt4: 1-4, adapt8: 1-7, adapt16: 3-15 */
 #ifndef dav1d_msac_decode_symbol_adapt4
 #define dav1d_msac_decode_symbol_adapt4  dav1d_msac_decode_symbol_adapt_c
 #endif
--- a/src/recon_tmpl.c
+++ b/src/recon_tmpl.c
@@ -104,11 +104,11 @@
                 dav1d_filter_mode_to_y_mode[b->y_angle] : b->y_mode;
             if (f->frame_hdr->reduced_txtp_set || t_dim->min == TX_16X16) {
                 idx = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                          ts->cdf.m.txtp_intra2[t_dim->min][y_mode_nofilt], 5);
+                          ts->cdf.m.txtp_intra2[t_dim->min][y_mode_nofilt], 4);
                 *txtp = dav1d_tx_types_per_set[idx + 0];
             } else {
                 idx = dav1d_msac_decode_symbol_adapt8(&ts->msac,
-                          ts->cdf.m.txtp_intra1[t_dim->min][y_mode_nofilt], 7);
+                          ts->cdf.m.txtp_intra1[t_dim->min][y_mode_nofilt], 6);
                 *txtp = dav1d_tx_types_per_set[idx + 5];
             }
             if (dbg)
@@ -121,11 +121,11 @@
                 *txtp = (idx - 1) & IDTX; /* idx ? DCT_DCT : IDTX */
             } else if (t_dim->min == TX_16X16) {
                 idx = dav1d_msac_decode_symbol_adapt16(&ts->msac,
-                          ts->cdf.m.txtp_inter2, 12);
+                          ts->cdf.m.txtp_inter2, 11);
                 *txtp = dav1d_tx_types_per_set[idx + 12];
             } else {
                 idx = dav1d_msac_decode_symbol_adapt16(&ts->msac,
-                          ts->cdf.m.txtp_inter1[t_dim->min], 16);
+                          ts->cdf.m.txtp_inter1[t_dim->min], 15);
                 *txtp = dav1d_tx_types_per_set[idx + 24];
             }
             if (dbg)
@@ -143,7 +143,7 @@
 #define case_sz(sz, bin, ns, is_1d) \
     case sz: { \
         uint16_t *const eob_bin_cdf = ts->cdf.coef.eob_bin_##bin[chroma]is_1d; \
-        eob_bin = dav1d_msac_decode_symbol_adapt##ns(&ts->msac, eob_bin_cdf, 5 + sz); \
+        eob_bin = dav1d_msac_decode_symbol_adapt##ns(&ts->msac, eob_bin_cdf, 4 + sz); \
         break; \
     }
     case_sz(0,   16,  4, [is_1d]);
@@ -175,7 +175,7 @@
     }
 
     // base tokens
-    uint16_t (*const br_cdf)[5] =
+    uint16_t (*const br_cdf)[4] =
         ts->cdf.coef.br_tok[imin(t_dim->ctx, 3)][chroma];
     const int16_t *const scan = dav1d_scans[tx][tx_class];
     int dc_tok;
@@ -193,7 +193,7 @@
             const int ctx = 1 + (eob > sw * sh * 2) + (eob > sw * sh * 4);
             uint16_t *const lo_cdf = ts->cdf.coef.eob_base_tok[t_dim->ctx][chroma][ctx];
 
-            int tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf, 3);
+            int tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf, 2);
             int tok = 1 + tok_br;
             if (dbg)
                 printf("Post-lo_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
@@ -209,19 +209,19 @@
                 const int br_ctx = get_br_ctx(levels, 1, tx_class, x, y, stride);
 
                 tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                            br_cdf[br_ctx], 4);
+                            br_cdf[br_ctx], 3);
                 tok = 3 + tok_br;
                 dbg_print_hi_tok(eob, tok + tok_br, tok_br);
 
                 if (tok_br == 3) {
                     tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                             br_cdf[br_ctx], 4);
+                                                             br_cdf[br_ctx], 3);
                     tok = 6 + tok_br;
                     dbg_print_hi_tok(eob, tok + tok_br, tok_br);
                     if (tok_br == 3) {
                         tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
                                                                  br_cdf[br_ctx],
-                                                                 4);
+                                                                 3);
                         tok = 9 + tok_br;
                         dbg_print_hi_tok(eob, tok + tok_br, tok_br);
                         if (tok_br == 3) {
@@ -228,7 +228,7 @@
                             tok = 12 +
                                 dav1d_msac_decode_symbol_adapt4(&ts->msac,
                                                                 br_cdf[br_ctx],
-                                                                4);
+                                                                3);
                             dbg_print_hi_tok(eob, tok + tok_br, tok_br);
                         }
                     }
@@ -244,7 +244,7 @@
             // lo tok
             const int ctx = get_coef_nz_ctx(levels, tx, tx_class, x, y, stride);
             uint16_t *const lo_cdf = ts->cdf.coef.base_tok[t_dim->ctx][chroma][ctx];
-            int tok = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf, 4);
+            int tok = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf, 3);
             if (dbg)
                 printf("Post-lo_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
                        t_dim->ctx, chroma, ctx, i, rc, tok, ts->msac.rng);
@@ -254,13 +254,13 @@
                 const int br_ctx = get_br_ctx(levels, 1, tx_class, x, y, stride);
 
                 int tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                             br_cdf[br_ctx], 4);
+                                                             br_cdf[br_ctx], 3);
                 tok = 3 + tok_br;
                 dbg_print_hi_tok(i, tok + tok_br, tok_br);
 
                 if (tok_br == 3) {
                     tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                             br_cdf[br_ctx], 4);
+                                                             br_cdf[br_ctx], 3);
 
                     tok = 6 + tok_br;
                     dbg_print_hi_tok(i, tok + tok_br, tok_br);
@@ -267,13 +267,13 @@
                     if (tok_br == 3) {
                         tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
                                                                  br_cdf[br_ctx],
-                                                                 4);
+                                                                 3);
                         tok = 9 + tok_br;
                         dbg_print_hi_tok(i, tok + tok_br, tok_br);
                         if (tok_br == 3) {
                             tok = 12 + dav1d_msac_decode_symbol_adapt4(&ts->msac,
                                                                        br_cdf[br_ctx],
-                                                                       4);
+                                                                       3);
                             dbg_print_hi_tok(i, tok + tok_br, tok_br);
                         }
                     }
@@ -287,7 +287,7 @@
             const int ctx = (tx_class != TX_CLASS_2D) ?
                 get_coef_nz_ctx(levels, tx, tx_class, 0, 0, stride) : 0;
             uint16_t *const lo_cdf = ts->cdf.coef.base_tok[t_dim->ctx][chroma][ctx];
-            dc_tok = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf, 4);
+            dc_tok = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf, 3);
             if (dbg)
                 printf("Post-dc_lo_tok[%d][%d][%d][%d]: r=%d\n",
                        t_dim->ctx, chroma, ctx, dc_tok, ts->msac.rng);
@@ -302,7 +302,7 @@
                 const int br_ctx = get_br_ctx(levels, 0, tx_class, 0, 0, stride);
 
                 int tok_br =
-                    dav1d_msac_decode_symbol_adapt4(&ts->msac, br_cdf[br_ctx], 4);
+                    dav1d_msac_decode_symbol_adapt4(&ts->msac, br_cdf[br_ctx], 3);
                 dc_tok = 3 + tok_br;
 
                 dbg_print_hi_tok(dc_tok + tok_br, tok_br);
@@ -309,13 +309,13 @@
 
                 if (tok_br == 3) {
                     tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                             br_cdf[br_ctx], 4);
+                                                             br_cdf[br_ctx], 3);
                     dc_tok = 6 + tok_br;
                     dbg_print_hi_tok(dc_tok + tok_br, tok_br);
                     if (tok_br == 3) {
                         tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
                                                                  br_cdf[br_ctx],
-                                                                 4);
+                                                                 3);
                         dc_tok = 9 + tok_br;
                         dbg_print_hi_tok(dc_tok + tok_br, tok_br);
                         if (tok_br == 3) {
@@ -322,7 +322,7 @@
                             dc_tok = 12 +
                                 dav1d_msac_decode_symbol_adapt4(&ts->msac,
                                                                 br_cdf[br_ctx],
-                                                                4);
+                                                                3);
                             dbg_print_hi_tok(dc_tok + tok_br, tok_br);
                         }
                     }
@@ -332,7 +332,7 @@
         }
     } else { // dc-only
         uint16_t *const lo_cdf = ts->cdf.coef.eob_base_tok[t_dim->ctx][chroma][0];
-        int tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf, 3);
+        int tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf, 2);
         dc_tok = 1 + tok_br;
         if (dbg)
             printf("Post-dc_lo_tok[%d][%d][%d][%d]: r=%d\n",
@@ -345,24 +345,24 @@
         printf("Post-dc_hi_tok[%d][%d][0][%d->%d]: r=%d\n", \
                imin(t_dim->ctx, 3), chroma, tok_br, dc_tok, ts->msac.rng);
 
-            tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, br_cdf[0], 4);
+            tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, br_cdf[0], 3);
             dc_tok = 3 + tok_br;
 
             dbg_print_hi_tok(dc_tok + tok_br, tok_br);
 
             if (tok_br == 3) {
-                tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, br_cdf[0], 4);
+                tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, br_cdf[0], 3);
                 dc_tok = 6 + tok_br;
                 dbg_print_hi_tok(dc_tok + tok_br, tok_br);
                 if (tok_br == 3) {
                     tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                             br_cdf[0], 4);
+                                                             br_cdf[0], 3);
                     dc_tok = 9 + tok_br;
                     dbg_print_hi_tok(dc_tok + tok_br, tok_br);
                     if (tok_br == 3) {
                         dc_tok = 12 +
                             dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                            br_cdf[0], 4);
+                                                            br_cdf[0], 3);
                         dbg_print_hi_tok(dc_tok + tok_br, tok_br);
                     }
                 }
--- a/src/tables.c
+++ b/src/tables.c
@@ -225,6 +225,14 @@
     [NEARMV_NEWMV]        = { NEARMV,    NEWMV     },
 };
 
+const uint8_t dav1d_partition_type_count[N_BL_LEVELS] = {
+    [BL_128X128] = N_PARTITIONS - 3,
+    [BL_64X64]   = N_PARTITIONS - 1,
+    [BL_32X32]   = N_PARTITIONS - 1,
+    [BL_16X16]   = N_PARTITIONS - 1,
+    [BL_8X8]     = N_SUB8X8_PARTITIONS - 1,
+};
+
 const uint8_t /* enum TxfmType */ dav1d_tx_types_per_set[40] = {
     /* Intra2 */
     IDTX, DCT_DCT, ADST_ADST, ADST_DCT, DCT_ADST,
--- a/src/tables.h
+++ b/src/tables.h
@@ -52,6 +52,7 @@
 extern const uint8_t /* enum InterPredMode */
                      dav1d_comp_inter_pred_modes[N_COMP_INTER_PRED_MODES][2];
 
+extern const uint8_t dav1d_partition_type_count[N_BL_LEVELS];
 extern const uint8_t /* enum TxfmType */ dav1d_tx_types_per_set[40];
 
 extern const uint8_t dav1d_filter_mode_to_y_mode[5];
--- a/src/x86/msac.asm
+++ b/src/x86/msac.asm
@@ -88,7 +88,7 @@
     movp           m3, [t0+msac.dif]
     mov           t3d, [t0+msac.update_cdf]
     mov           t4d, t2d
-    neg            t2
+    not            t2     ; -(n_symbols + 1)
     pshuflw        m2, m2, q0000
     movd     [buf+12], m2
     pand           m2, [rax]
@@ -112,8 +112,8 @@
     pcmpeqw        m2, m2
     mov           t2d, t3d
     shr           t3d, 4
-    cmp           t4d, 4
-    sbb           t3d, -5 ; (count >> 4) + (n_symbols > 3) + 4
+    cmp           t4d, 3
+    sbb           t3d, -5 ; (count >> 4) + (n_symbols > 2) + 4
     cmp           t2d, 32
     adc           t2d, 0  ; count + (count < 32)
     movd           m3, t3d
@@ -120,7 +120,7 @@
     pavgw          m2, m1 ; i >= val ? -1 : 32768
     psubw          m2, m0 ; for (i = 0; i < val; i++)
     psubw          m0, m1 ;     cdf[i] += (32768 - cdf[i]) >> rate;
-    psraw          m2, m3 ; for (; i < n_symbols - 1; i++)
+    psraw          m2, m3 ; for (; i < n_symbols; i++)
     paddw          m0, m2 ;     cdf[i] += ((  -1 - cdf[i]) >> rate) + 1;
     movq         [t1], m0
     mov     [t1+t4*2], t2w
@@ -214,11 +214,11 @@
     DECODE_SYMBOL_ADAPT_INIT
     LEA           rax, pw_0xff00
     movd           m2, [t0+msac.rng]
-    movu           m1, [t1]
+    mova           m1, [t1]
     movp           m3, [t0+msac.dif]
     mov           t3d, [t0+msac.update_cdf]
     mov           t4d, t2d
-    neg            t2
+    not            t2
     pshuflw        m2, m2, q0000
     movd     [buf+12], m2
     punpcklqdq     m2, m2
@@ -242,7 +242,7 @@
     pcmpeqw        m2, m2
     mov           t2d, t3d
     shr           t3d, 4
-    cmp           t4d, 4 ; may be called with n_symbols < 4
+    cmp           t4d, 3 ; may be called with n_symbols <= 2
     sbb           t3d, -5
     cmp           t2d, 32
     adc           t2d, 0
@@ -252,7 +252,7 @@
     psubw          m0, m1
     psraw          m2, m3
     paddw          m0, m2
-    movu         [t1], m0
+    mova         [t1], m0
     mov     [t1+t4*2], t2w
     jmp m(msac_decode_symbol_adapt4).renorm
 
@@ -260,12 +260,12 @@
     DECODE_SYMBOL_ADAPT_INIT
     LEA           rax, pw_0xff00
     movd           m4, [t0+msac.rng]
-    movu           m2, [t1]
-    movu           m3, [t1+16]
+    mova           m2, [t1]
+    mova           m3, [t1+16]
     movp           m5, [t0+msac.dif]
     mov           t3d, [t0+msac.update_cdf]
     mov           t4d, t2d
-    neg            t2
+    not            t2
 %if WIN64
     sub           rsp, 48 ; need 36 bytes, shadow space is only 32
 %endif
@@ -288,8 +288,8 @@
     punpcklqdq     m5, m5
     paddw          m3, m4
     mova        [buf], m2
-    mova     [buf+16], m3
     psubusw        m2, m5
+    mova     [buf+16], m3
     psubusw        m3, m5
     pxor           m4, m4
     pcmpeqw        m2, m4
@@ -301,7 +301,7 @@
     movzx         t3d, word [t1+t4*2]
     pcmpeqw        m4, m4
     mova           m5, m4
-    lea           t2d, [t3+80] ; only support n_symbols >= 4
+    lea           t2d, [t3+80] ; only support n_symbols > 2
     shr           t2d, 4
     cmp           t3d, 32
     adc           t3d, 0
@@ -316,8 +316,8 @@
     psraw          m5, m2
     paddw          m0, m4
     paddw          m1, m5
-    movu         [t1], m0
-    movu      [t1+16], m1
+    mova         [t1], m0
+    mova      [t1+16], m1
     mov     [t1+t4*2], t3w
 .renorm:
     tzcnt         eax, eax
--- a/tests/checkasm/msac.c
+++ b/tests/checkasm/msac.c
@@ -51,12 +51,14 @@
     decode_bool_fn         bool;
 } MsacDSPContext;
 
-static void randomize_cdf(uint16_t *const cdf, int n) {
-    for (int i = 16; i > n; i--)
-        cdf[i] = rnd(); /* randomize padding */
-    cdf[n] = cdf[n-1] = 0;
-    while (--n > 0)
-        cdf[n-1] = cdf[n] + rnd() % (32768 - cdf[n] - n) + 1;
+static void randomize_cdf(uint16_t *const cdf, const int n) {
+    int i;
+    for (i = 15; i > n; i--)
+        cdf[i] = rnd(); // padding
+    cdf[i] = 0;         // count
+    do {
+        cdf[i - 1] = cdf[i] + rnd() % (32768 - cdf[i] - i) + 1;
+    } while (--i > 0);
 }
 
 /* memcmp() on structs can have weird behavior due to padding etc. */
@@ -69,7 +71,7 @@
 static void msac_dump(unsigned c_res, unsigned a_res,
                       const MsacContext *const a, const MsacContext *const b,
                       const uint16_t *const cdf_a, const uint16_t *const cdf_b,
-                      int num_cdf)
+                      const int num_cdf)
 {
     if (c_res != a_res)
         fprintf(stderr, "c_res %u a_res %u\n", c_res, a_res);
@@ -86,16 +88,15 @@
     if (a->allow_update_cdf)
         fprintf(stderr, "allow_update_cdf %d vs %d\n",
                 a->allow_update_cdf, b->allow_update_cdf);
-    if (cdf_a != NULL && cdf_b != NULL &&
-        memcmp(cdf_a, cdf_b, sizeof(*cdf_a) * num_cdf)) {
+    if (num_cdf && memcmp(cdf_a, cdf_b, sizeof(*cdf_a) * (num_cdf + 1))) {
         fprintf(stderr, "cdf:\n");
-        for (int i = 0; i < num_cdf; i++)
+        for (int i = 0; i <= num_cdf; i++)
             fprintf(stderr, " %5u", cdf_a[i]);
         fprintf(stderr, "\n");
-        for (int i = 0; i < num_cdf; i++)
+        for (int i = 0; i <= num_cdf; i++)
             fprintf(stderr, " %5u", cdf_b[i]);
         fprintf(stderr, "\n");
-        for (int i = 0; i < num_cdf; i++)
+        for (int i = 0; i <= num_cdf; i++)
             fprintf(stderr, "     %c", cdf_a[i] != cdf_b[i] ? 'x' : '.');
         fprintf(stderr, "\n");
     }
@@ -105,7 +106,7 @@
     if (check_func(c->symbol_adapt##n, "msac_decode_symbol_adapt%d", n)) { \
         for (int cdf_update = 0; cdf_update <= 1; cdf_update++) {          \
             for (int ns = n_min; ns <= n_max; ns++) {                      \
-                dav1d_msac_init(&s_c, buf, BUF_SIZE, !cdf_update);      \
+                dav1d_msac_init(&s_c, buf, BUF_SIZE, !cdf_update);         \
                 s_a = s_c;                                                 \
                 randomize_cdf(cdf[0], ns);                                 \
                 memcpy(cdf[1], cdf[0], sizeof(*cdf));                      \
@@ -117,11 +118,11 @@
                     {                                                      \
                         if (fail())                                        \
                             msac_dump(c_res, a_res, &s_c, &s_a,            \
-                                      cdf[0], cdf[1], ns + 1);             \
+                                      cdf[0], cdf[1], ns);                 \
                     }                                                      \
                 }                                                          \
-                if (cdf_update && ns == n)                                 \
-                    bench_new(&s_a, cdf[0], n);                            \
+                if (cdf_update && ns == n - 1)                             \
+                    bench_new(&s_a, cdf[1], ns);                           \
             }                                                              \
         }                                                                  \
     }                                                                      \
@@ -128,15 +129,13 @@
 } while (0)
 
 static void check_decode_symbol(MsacDSPContext *const c, uint8_t *const buf) {
-    /* Use an aligned CDF buffer for more consistent benchmark
-     * results, and a misaligned one for checking correctness. */
-    ALIGN_STK_16(uint16_t, cdf, 2, [17]);
+    ALIGN_STK_32(uint16_t, cdf, 2, [16]);
     MsacContext s_c, s_a;
 
     declare_func(unsigned, MsacContext *s, uint16_t *cdf, size_t n_symbols);
-    CHECK_SYMBOL_ADAPT( 4, 1,  5);
-    CHECK_SYMBOL_ADAPT( 8, 1,  8);
-    CHECK_SYMBOL_ADAPT(16, 4, 16);
+    CHECK_SYMBOL_ADAPT( 4, 1,  4);
+    CHECK_SYMBOL_ADAPT( 8, 1,  7);
+    CHECK_SYMBOL_ADAPT(16, 3, 15);
     report("decode_symbol");
 }
 
@@ -158,11 +157,11 @@
                     memcmp(cdf[0], cdf[1], sizeof(*cdf)))
                 {
                     if (fail())
-                        msac_dump(c_res, a_res, &s_c, &s_a, cdf[0], cdf[1], 2);
+                        msac_dump(c_res, a_res, &s_c, &s_a, cdf[0], cdf[1], 1);
                 }
             }
             if (cdf_update)
-                bench_new(&s_a, cdf[0]);
+                bench_new(&s_a, cdf[1]);
         }
     }