shithub: dav1d

Download patch

ref: fa1b265142e1409a986f01bd7abe115b308c1028
parent: 44d0de41d478b6b41a1ebbf1de012caa8d75cca0
author: Henrik Gramner <[email protected]>
date: Thu Apr 11 19:20:18 EDT 2019

x86-64: Add msac_decode_symbol_adapt SSE2 asm

Also make various minor optimizations/style fixes to the MSAC C functions, and pad the CDF arrays read by the SIMD code to 8 or 16 entries where needed.

--- a/src/cdf.c
+++ b/src/cdf.c
@@ -813,7 +813,7 @@
     AOM_CDF4(4096, 11264, 19328)
 };
 
-static const uint16_t default_kf_y_mode_cdf[5][5][N_INTRA_PRED_MODES + 1] = {
+static const uint16_t default_kf_y_mode_cdf[5][5][N_INTRA_PRED_MODES + 1 + 2] = {
     {
         { AOM_CDF13(15588, 17027, 19338, 20218, 20682, 21110, 21825, 23244,
                     24189, 28165, 29093, 30466) },
--- a/src/cdf.h
+++ b/src/cdf.h
@@ -34,11 +34,13 @@
 #include "src/ref.h"
 #include "src/thread_data.h"
 
+/* Buffers padded to [8] or [16] for SIMD where needed. */
+
 typedef struct CdfModeContext {
-    uint16_t y_mode[4][N_INTRA_PRED_MODES + 1];
+    uint16_t y_mode[4][N_INTRA_PRED_MODES + 1 + 2];
     uint16_t use_filter_intra[N_BS_SIZES][2];
     uint16_t filter_intra[5 + 1];
-    uint16_t uv_mode[2][N_INTRA_PRED_MODES][N_UV_INTRA_PRED_MODES + 1];
+    uint16_t uv_mode[2][N_INTRA_PRED_MODES][N_UV_INTRA_PRED_MODES + 1 + 1];
     uint16_t angle_delta[8][8];
     uint16_t filter[2][8][DAV1D_N_SWITCHABLE_FILTERS + 1];
     uint16_t newmv_mode[6][2];
@@ -66,7 +68,7 @@
     uint16_t txtp_intra[3][N_TX_SIZES][N_INTRA_PRED_MODES][N_TX_TYPES + 1];
     uint16_t skip[3][2];
     uint16_t skip_mode[3][2];
-    uint16_t partition[N_BL_LEVELS][4][N_PARTITIONS + 1];
+    uint16_t partition[N_BL_LEVELS][4][N_PARTITIONS + 1 + 5];
     uint16_t seg_pred[3][2];
     uint16_t seg_id[3][DAV1D_MAX_SEGMENTS + 1];
     uint16_t cfl_sign[8 + 1];
@@ -88,12 +90,12 @@
 typedef struct CdfCoefContext {
     uint16_t skip[N_TX_SIZES][13][2];
     uint16_t eob_bin_16[2][2][6];
-    uint16_t eob_bin_32[2][2][7];
+    uint16_t eob_bin_32[2][2][7 + 1];
     uint16_t eob_bin_64[2][2][8];
     uint16_t eob_bin_128[2][2][9];
-    uint16_t eob_bin_256[2][2][10];
-    uint16_t eob_bin_512[2][2][11];
-    uint16_t eob_bin_1024[2][2][12];
+    uint16_t eob_bin_256[2][2][10 + 6];
+    uint16_t eob_bin_512[2][2][11 + 5];
+    uint16_t eob_bin_1024[2][2][12 + 4];
     uint16_t eob_hi_bit[N_TX_SIZES][2][11 /*22*/][2];
     uint16_t eob_base_tok[N_TX_SIZES][2][4][4];
     uint16_t base_tok[N_TX_SIZES][2][41][5];
@@ -102,7 +104,7 @@
 } CdfCoefContext;
 
 typedef struct CdfMvComponent {
-    uint16_t classes[11 + 1];
+    uint16_t classes[11 + 1 + 4];
     uint16_t class0[2];
     uint16_t classN[10][2];
     uint16_t class0_fp[2][4 + 1];
@@ -119,7 +121,7 @@
 
 typedef struct CdfContext {
     CdfModeContext m;
-    uint16_t kfym[5][5][N_INTRA_PRED_MODES + 1];
+    uint16_t kfym[5][5][N_INTRA_PRED_MODES + 1 + 2]; /* +2: SIMD padding (read via msac adapt16); keep in sync with default_kf_y_mode_cdf in cdf.c */
     CdfCoefContext coef;
     CdfMvContext mv, dmv;
 } CdfContext;
--- a/src/decode.c
+++ b/src/decode.c
@@ -80,15 +80,15 @@
     const Dav1dFrameContext *const f = t->f;
     const int have_hp = f->frame_hdr->hp;
     const int sign = dav1d_msac_decode_bool_adapt(&ts->msac, mv_comp->sign);
-    const int cl = dav1d_msac_decode_symbol_adapt(&ts->msac,
-                                                  mv_comp->classes, 11);
+    const int cl = dav1d_msac_decode_symbol_adapt16(&ts->msac,
+                                                    mv_comp->classes, 11);
     int up, fp, hp;
 
     if (!cl) {
         up = dav1d_msac_decode_bool_adapt(&ts->msac, mv_comp->class0);
         if (have_fp) {
-            fp = dav1d_msac_decode_symbol_adapt(&ts->msac,
-                                                mv_comp->class0_fp[up], 4);
+            fp = dav1d_msac_decode_symbol_adapt4(&ts->msac,
+                                                 mv_comp->class0_fp[up], 4);
             hp = have_hp ? dav1d_msac_decode_bool_adapt(&ts->msac,
                                                         mv_comp->class0_hp) : 1;
         } else {
@@ -101,8 +101,8 @@
             up |= dav1d_msac_decode_bool_adapt(&ts->msac,
                                                mv_comp->classN[n]) << n;
         if (have_fp) {
-            fp = dav1d_msac_decode_symbol_adapt(&ts->msac,
-                                                mv_comp->classN_fp, 4);
+            fp = dav1d_msac_decode_symbol_adapt4(&ts->msac,
+                                                 mv_comp->classN_fp, 4);
             hp = have_hp ? dav1d_msac_decode_bool_adapt(&ts->msac,
                                                         mv_comp->classN_hp) : 1;
         } else {
@@ -119,8 +119,8 @@
 static void read_mv_residual(Dav1dTileContext *const t, mv *const ref_mv,
                              CdfMvContext *const mv_cdf, const int have_fp)
 {
-    switch (dav1d_msac_decode_symbol_adapt(&t->ts->msac, t->ts->cdf.mv.joint,
-                                           N_MV_JOINTS))
+    switch (dav1d_msac_decode_symbol_adapt4(&t->ts->msac, t->ts->cdf.mv.joint,
+                                            N_MV_JOINTS))
     {
     case MV_JOINT_HV:
         ref_mv->y += read_mv_component_diff(t, &mv_cdf->comp[0], have_fp);
@@ -379,7 +379,7 @@
 {
     Dav1dTileState *const ts = t->ts;
     const Dav1dFrameContext *const f = t->f;
-    const int pal_sz = b->pal_sz[pl] = dav1d_msac_decode_symbol_adapt(&ts->msac,
+    const int pal_sz = b->pal_sz[pl] = dav1d_msac_decode_symbol_adapt8(&ts->msac,
                                            ts->cdf.m.pal_sz[pl][sz_ctx], 7) + 2;
     uint16_t cache[16], used_cache[8];
     int l_cache = pl ? t->pal_sz_uv[1][by4] : t->l.pal_sz[by4];
@@ -595,7 +595,7 @@
         const int last = imax(0, i - h4 * 4 + 1);
         order_palette(pal_idx, stride, i, first, last, order, ctx);
         for (int j = first, m = 0; j >= last; j--, m++) {
-            const int color_idx = dav1d_msac_decode_symbol_adapt(&ts->msac,
+            const int color_idx = dav1d_msac_decode_symbol_adapt8(&ts->msac,
                                       color_map_cdf[ctx[m]], b->pal_sz[pl]);
             pal_idx[(i - j) * stride + j] = order[m][color_idx];
         }
@@ -811,7 +811,7 @@
                 const unsigned pred_seg_id =
                     get_cur_frame_segid(t->by, t->bx, have_top, have_left,
                                         &seg_ctx, f->cur_segmap, f->b4_stride);
-                const unsigned diff = dav1d_msac_decode_symbol_adapt(&ts->msac,
+                const unsigned diff = dav1d_msac_decode_symbol_adapt8(&ts->msac,
                                           ts->cdf.m.seg_id[seg_ctx],
                                           DAV1D_MAX_SEGMENTS);
                 const unsigned last_active_seg_id =
@@ -883,7 +883,7 @@
             if (b->skip) {
                 b->seg_id = pred_seg_id;
             } else {
-                const unsigned diff = dav1d_msac_decode_symbol_adapt(&ts->msac,
+                const unsigned diff = dav1d_msac_decode_symbol_adapt8(&ts->msac,
                                           ts->cdf.m.seg_id[seg_ctx],
                                           DAV1D_MAX_SEGMENTS);
                 const unsigned last_active_seg_id =
@@ -932,8 +932,8 @@
         memcpy(prev_delta_lf, ts->last_delta_lf, 4);
 
         if (have_delta_q) {
-            int delta_q = dav1d_msac_decode_symbol_adapt(&ts->msac,
-                                                         ts->cdf.m.delta_q, 4);
+            int delta_q = dav1d_msac_decode_symbol_adapt4(&ts->msac,
+                                                          ts->cdf.m.delta_q, 4);
             if (delta_q == 3) {
                 const int n_bits = 1 + dav1d_msac_decode_bools(&ts->msac, 3);
                 delta_q = dav1d_msac_decode_bools(&ts->msac, n_bits) +
@@ -953,7 +953,7 @@
                     f->cur.p.layout != DAV1D_PIXEL_LAYOUT_I400 ? 4 : 2 : 1;
 
                 for (int i = 0; i < n_lfs; i++) {
-                    int delta_lf = dav1d_msac_decode_symbol_adapt(&ts->msac,
+                    int delta_lf = dav1d_msac_decode_symbol_adapt4(&ts->msac,
                         ts->cdf.m.delta_lf[i + f->frame_hdr->delta.lf.multi], 4);
                     if (delta_lf == 3) {
                         const int n_bits = 1 + dav1d_msac_decode_bools(&ts->msac, 3);
@@ -1018,8 +1018,8 @@
             ts->cdf.m.y_mode[dav1d_ymode_size_context[bs]] :
             ts->cdf.kfym[dav1d_intra_mode_context[t->a->mode[bx4]]]
                         [dav1d_intra_mode_context[t->l.mode[by4]]];
-        b->y_mode = dav1d_msac_decode_symbol_adapt(&ts->msac, ymode_cdf,
-                                                   N_INTRA_PRED_MODES);
+        b->y_mode = dav1d_msac_decode_symbol_adapt16(&ts->msac, ymode_cdf,
+                                                     N_INTRA_PRED_MODES);
         if (DEBUG_BLOCK_INFO)
             printf("Post-ymode[%d]: r=%d\n", b->y_mode, ts->msac.rng);
 
@@ -1028,7 +1028,7 @@
             b->y_mode <= VERT_LEFT_PRED)
         {
             uint16_t *const acdf = ts->cdf.m.angle_delta[b->y_mode - VERT_PRED];
-            const int angle = dav1d_msac_decode_symbol_adapt(&ts->msac, acdf, 7);
+            const int angle = dav1d_msac_decode_symbol_adapt8(&ts->msac, acdf, 7);
             b->y_angle = angle - 3;
         } else {
             b->y_angle = 0;
@@ -1038,7 +1038,7 @@
             const int cfl_allowed = f->frame_hdr->segmentation.lossless[b->seg_id] ?
                 cbw4 == 1 && cbh4 == 1 : !!(cfl_allowed_mask & (1 << bs));
             uint16_t *const uvmode_cdf = ts->cdf.m.uv_mode[cfl_allowed][b->y_mode];
-            b->uv_mode = dav1d_msac_decode_symbol_adapt(&ts->msac, uvmode_cdf,
+            b->uv_mode = dav1d_msac_decode_symbol_adapt16(&ts->msac, uvmode_cdf,
                              N_UV_INTRA_PRED_MODES - !cfl_allowed);
             if (DEBUG_BLOCK_INFO)
                 printf("Post-uvmode[%d]: r=%d\n", b->uv_mode, ts->msac.rng);
@@ -1045,13 +1045,13 @@
 
             if (b->uv_mode == CFL_PRED) {
 #define SIGN(a) (!!(a) + ((a) > 0))
-                const int sign = dav1d_msac_decode_symbol_adapt(&ts->msac,
+                const int sign = dav1d_msac_decode_symbol_adapt8(&ts->msac,
                                      ts->cdf.m.cfl_sign, 8) + 1;
                 const int sign_u = sign * 0x56 >> 8, sign_v = sign - sign_u * 3;
                 assert(sign_u == sign / 3);
                 if (sign_u) {
                     const int ctx = (sign_u == 2) * 3 + sign_v;
-                    b->cfl_alpha[0] = dav1d_msac_decode_symbol_adapt(&ts->msac,
+                    b->cfl_alpha[0] = dav1d_msac_decode_symbol_adapt16(&ts->msac,
                                           ts->cdf.m.cfl_alpha[ctx], 16) + 1;
                     if (sign_u == 1) b->cfl_alpha[0] = -b->cfl_alpha[0];
                 } else {
@@ -1059,7 +1059,7 @@
                 }
                 if (sign_v) {
                     const int ctx = (sign_v == 2) * 3 + sign_u;
-                    b->cfl_alpha[1] = dav1d_msac_decode_symbol_adapt(&ts->msac,
+                    b->cfl_alpha[1] = dav1d_msac_decode_symbol_adapt16(&ts->msac,
                                           ts->cdf.m.cfl_alpha[ctx], 16) + 1;
                     if (sign_v == 1) b->cfl_alpha[1] = -b->cfl_alpha[1];
                 } else {
@@ -1073,7 +1073,7 @@
                        b->uv_mode <= VERT_LEFT_PRED)
             {
                 uint16_t *const acdf = ts->cdf.m.angle_delta[b->uv_mode - VERT_PRED];
-                const int angle = dav1d_msac_decode_symbol_adapt(&ts->msac, acdf, 7);
+                const int angle = dav1d_msac_decode_symbol_adapt8(&ts->msac, acdf, 7);
                 b->uv_angle = angle - 3;
             } else {
                 b->uv_angle = 0;
@@ -1113,7 +1113,7 @@
                                       ts->cdf.m.use_filter_intra[bs]);
             if (is_filter) {
                 b->y_mode = FILTER_PRED;
-                b->y_angle = dav1d_msac_decode_symbol_adapt(&ts->msac,
+                b->y_angle = dav1d_msac_decode_symbol_adapt4(&ts->msac,
                                  ts->cdf.m.filter_intra, 5);
             }
             if (DEBUG_BLOCK_INFO)
@@ -1156,7 +1156,7 @@
             if (f->frame_hdr->txfm_mode == DAV1D_TX_SWITCHABLE && t_dim->max > TX_4X4) {
                 const int tctx = get_tx_ctx(t->a, &t->l, t_dim, by4, bx4);
                 uint16_t *const tx_cdf = ts->cdf.m.txsz[t_dim->max - 1][tctx];
-                int depth = dav1d_msac_decode_symbol_adapt(&ts->msac, tx_cdf,
+                int depth = dav1d_msac_decode_symbol_adapt4(&ts->msac, tx_cdf,
                                 imin(t_dim->max + 1, 3));
 
                 while (depth--) {
@@ -1474,7 +1474,7 @@
                              ts->tiling.col_end, ts->tiling.row_start,
                              ts->tiling.row_end, f->libaom_cm);
 
-            b->inter_mode = dav1d_msac_decode_symbol_adapt(&ts->msac,
+            b->inter_mode = dav1d_msac_decode_symbol_adapt8(&ts->msac,
                                 ts->cdf.m.comp_inter_mode[ctx],
                                 N_COMP_INTER_PRED_MODES);
             if (DEBUG_BLOCK_INFO)
@@ -1583,7 +1583,7 @@
                                    dav1d_msac_decode_bool_adapt(&ts->msac,
                                        ts->cdf.m.wedge_comp[ctx]);
                     if (b->comp_type == COMP_INTER_WEDGE)
-                        b->wedge_idx = dav1d_msac_decode_symbol_adapt(&ts->msac,
+                        b->wedge_idx = dav1d_msac_decode_symbol_adapt16(&ts->msac,
                                            ts->cdf.m.wedge_idx[ctx], 16);
                 } else {
                     b->comp_type = COMP_INTER_SEG;
@@ -1737,7 +1737,7 @@
                 dav1d_msac_decode_bool_adapt(&ts->msac,
                                              ts->cdf.m.interintra[ii_sz_grp]))
             {
-                b->interintra_mode = dav1d_msac_decode_symbol_adapt(&ts->msac,
+                b->interintra_mode = dav1d_msac_decode_symbol_adapt4(&ts->msac,
                                          ts->cdf.m.interintra_mode[ii_sz_grp],
                                          N_INTER_INTRA_PRED_MODES);
                 const int wedge_ctx = dav1d_wedge_ctx_lut[bs];
@@ -1745,7 +1745,7 @@
                                      dav1d_msac_decode_bool_adapt(&ts->msac,
                                          ts->cdf.m.interintra_wedge[wedge_ctx]);
                 if (b->interintra_type == INTER_INTRA_WEDGE)
-                    b->wedge_idx = dav1d_msac_decode_symbol_adapt(&ts->msac,
+                    b->wedge_idx = dav1d_msac_decode_symbol_adapt16(&ts->msac,
                                        ts->cdf.m.wedge_idx[wedge_ctx], 16);
             } else {
                 b->interintra_type = INTER_INTRA_NONE;
@@ -1778,7 +1778,7 @@
                     f->frame_hdr->warp_motion && (mask[0] | mask[1]);
 
                 b->motion_mode = allow_warp ?
-                    dav1d_msac_decode_symbol_adapt(&ts->msac,
+                    dav1d_msac_decode_symbol_adapt4(&ts->msac,
                         ts->cdf.m.motion_mode[bs], 3) :
                     dav1d_msac_decode_bool_adapt(&ts->msac, ts->cdf.m.obmc[bs]);
                 if (b->motion_mode == MM_WARP) {
@@ -1817,7 +1817,7 @@
                 const int comp = b->comp_type != COMP_INTER_NONE;
                 const int ctx1 = get_filter_ctx(t->a, &t->l, comp, 0, b->ref[0],
                                                 by4, bx4);
-                filter[0] = dav1d_msac_decode_symbol_adapt(&ts->msac,
+                filter[0] = dav1d_msac_decode_symbol_adapt4(&ts->msac,
                                ts->cdf.m.filter[0][ctx1],
                                DAV1D_N_SWITCHABLE_FILTERS);
                 if (f->seq_hdr->dual_filter) {
@@ -1826,7 +1826,7 @@
                     if (DEBUG_BLOCK_INFO)
                         printf("Post-subpel_filter1[%d,ctx=%d]: r=%d\n",
                                filter[0], ctx1, ts->msac.rng);
-                    filter[1] = dav1d_msac_decode_symbol_adapt(&ts->msac,
+                    filter[1] = dav1d_msac_decode_symbol_adapt4(&ts->msac,
                                     ts->cdf.m.filter[1][ctx2],
                                     DAV1D_N_SWITCHABLE_FILTERS);
                     if (DEBUG_BLOCK_INFO)
@@ -2021,7 +2021,7 @@
         } else {
             const unsigned n_part = bl == BL_8X8 ? N_SUB8X8_PARTITIONS :
                 bl == BL_128X128 ? N_PARTITIONS - 2 : N_PARTITIONS;
-            bp = dav1d_msac_decode_symbol_adapt(&t->ts->msac, pc, n_part);
+            bp = dav1d_msac_decode_symbol_adapt16(&t->ts->msac, pc, n_part);
             if (f->cur.p.layout == DAV1D_PIXEL_LAYOUT_I422 &&
                 (bp == PARTITION_V || bp == PARTITION_V4 ||
                  bp == PARTITION_T_LEFT_SPLIT || bp == PARTITION_T_RIGHT_SPLIT))
@@ -2365,7 +2365,7 @@
     Dav1dTileState *const ts = t->ts;
 
     if (frame_type == DAV1D_RESTORATION_SWITCHABLE) {
-        const int filter = dav1d_msac_decode_symbol_adapt(&ts->msac,
+        const int filter = dav1d_msac_decode_symbol_adapt4(&ts->msac,
                                ts->cdf.m.restore_switchable, 3);
         lr->type = filter ? filter == 2 ? DAV1D_RESTORATION_SGRPROJ :
                                           DAV1D_RESTORATION_WIENER :
--- a/src/meson.build
+++ b/src/meson.build
@@ -119,6 +119,7 @@
         # NASM source files
         libdav1d_sources_asm = files(
             'x86/cpuid.asm',
+            'x86/msac.asm', # SSE2 dav1d_msac_decode_symbol_adapt{4,8,16}
         )
 
         if dav1d_bitdepths.contains('8')
--- a/src/msac.c
+++ b/src/msac.c
@@ -58,8 +58,8 @@
  * necessary), and stores them back in the decoder context.
  * dif: The new value of dif.
  * rng: The new value of the range. */
-static inline void ctx_norm(MsacContext *s, ec_win dif, uint32_t rng) {
-    const uint16_t d = 15 - (31 ^ clz(rng));
+static inline void ctx_norm(MsacContext *s, ec_win dif, unsigned rng) {
+    const int d = 15 ^ (31 ^ clz(rng));
     assert(rng <= 65535U);
     s->cnt -= d;
     s->dif = ((dif + 1) << d) - 1; /* Shift in 1s in the LSBs */
@@ -69,18 +69,17 @@
 }
 
 unsigned dav1d_msac_decode_bool_equi(MsacContext *const s) {
-    ec_win v, vw, dif = s->dif;
-    uint16_t r = s->rng;
-    unsigned ret;
+    ec_win vw, dif = s->dif;
+    unsigned ret, v, r = s->rng; // rng is <= 65535 (see ctx_norm's assert), so unsigned is wide enough
     assert((dif >> (EC_WIN_SIZE - 16)) < r);
     // When the probability is 1/2, f = 16384 >> EC_PROB_SHIFT = 256 and we can
     // replace the multiply with a simple shift.
     v = ((r >> 8) << 7) + EC_MIN_PROB;
-    vw   = v << (EC_WIN_SIZE - 16);
+    vw   = (ec_win)v << (EC_WIN_SIZE - 16); // widen first: the shift count may exceed the width of unsigned
     ret  = dif >= vw;
     dif -= ret*vw;
     v   += ret*(r - 2*v);
-    ctx_norm(s, dif, (unsigned) v);
+    ctx_norm(s, dif, v); // v is now the size of the selected sub-interval
     return !ret;
 }
 
@@ -88,27 +87,26 @@
  * f: The probability that the bit is one
  * Return: The value decoded (0 or 1). */
 unsigned dav1d_msac_decode_bool(MsacContext *const s, const unsigned f) {
-    ec_win v, vw, dif = s->dif;
-    uint16_t r = s->rng;
-    unsigned ret;
+    ec_win vw, dif = s->dif;
+    unsigned ret, v, r = s->rng; // rng is <= 65535 (see ctx_norm's assert), so unsigned is wide enough
     assert((dif >> (EC_WIN_SIZE - 16)) < r);
     v = ((r >> 8) * (f >> EC_PROB_SHIFT) >> (7 - EC_PROB_SHIFT)) + EC_MIN_PROB;
-    vw   = v << (EC_WIN_SIZE - 16);
+    vw   = (ec_win)v << (EC_WIN_SIZE - 16); // widen first: the shift count may exceed the width of unsigned
     ret  = dif >= vw;
     dif -= ret*vw;
     v   += ret*(r - 2*v);
-    ctx_norm(s, dif, (unsigned) v);
+    ctx_norm(s, dif, v); // v is now the size of the selected sub-interval
     return !ret;
 }
 
-unsigned dav1d_msac_decode_bools(MsacContext *const c, const unsigned l) {
-    int v = 0;
-    for (int n = (int) l - 1; n >= 0; n--)
-        v = (v << 1) | dav1d_msac_decode_bool_equi(c);
+unsigned dav1d_msac_decode_bools(MsacContext *const s, unsigned n) { // reads n equiprobable bits; the first one decoded ends up as the MSB
+    unsigned v = 0; // unsigned: the left shifts cannot hit signed-overflow UB
+    while (n--)
+        v = (v << 1) | dav1d_msac_decode_bool_equi(s);
     return v;
 }
 
-int dav1d_msac_decode_subexp(MsacContext *const c, const int ref,
+int dav1d_msac_decode_subexp(MsacContext *const s, const int ref,
                              const int n, const unsigned k)
 {
     int i = 0;
@@ -115,32 +113,31 @@
     int a = 0;
     int b = k;
     while ((2 << b) < n) {
-        if (!dav1d_msac_decode_bool_equi(c)) break;
+        if (!dav1d_msac_decode_bool_equi(s)) break;
         b = k + i++;
         a = (1 << b);
     }
-    const unsigned v = dav1d_msac_decode_bools(c, b) + a;
+    const unsigned v = dav1d_msac_decode_bools(s, b) + a; // b raw bits, offset by a (0, or 1 << b once the loop has advanced)
     return ref * 2 <= n ? inv_recenter(ref, v) :
                           n - 1 - inv_recenter(n - 1 - ref, v);
 }
 
-int dav1d_msac_decode_uniform(MsacContext *const c, const unsigned n) {
+int dav1d_msac_decode_uniform(MsacContext *const s, const unsigned n) { // value in [0, n), assuming ulog2() is floor(log2)
     assert(n > 0);
     const int l = ulog2(n) + 1;
     assert(l > 1);
     const unsigned m = (1 << l) - n;
-    const unsigned v = dav1d_msac_decode_bools(c, l - 1);
-    return v < m ? v : (v << 1) - m + dav1d_msac_decode_bool_equi(c);
+    const unsigned v = dav1d_msac_decode_bools(s, l - 1); // the first l - 1 bits
+    return v < m ? v : (v << 1) - m + dav1d_msac_decode_bool_equi(s); // v >= m: one more bit picks between two codes
 }
 
 /* Decodes a symbol given an inverse cumulative distribution function (CDF)
  * table in Q15. */
 static unsigned decode_symbol(MsacContext *const s, const uint16_t *const cdf,
-                              const unsigned n_symbols)
+                              const size_t n_symbols)
 {
-    ec_win u, v = s->rng, r = s->rng >> 8;
-    const ec_win c = s->dif >> (EC_WIN_SIZE - 16);
-    unsigned ret = 0;
+    const unsigned c = s->dif >> (EC_WIN_SIZE - 16);
+    unsigned u, v = s->rng, r = s->rng >> 8, ret = 0;
 
     assert(!cdf[n_symbols - 1]);
 
@@ -153,39 +150,34 @@
 
     assert(u <= s->rng);
 
-    ctx_norm(s, s->dif - (v << (EC_WIN_SIZE - 16)), (unsigned) (u - v));
+    ctx_norm(s, s->dif - ((ec_win)v << (EC_WIN_SIZE - 16)), u - v);
     return ret - 1;
 }
 
-static void update_cdf(uint16_t *const cdf, const unsigned val,
-                       const unsigned n_symbols)
+unsigned dav1d_msac_decode_symbol_adapt_c(MsacContext *const s,
+                                          uint16_t *const cdf,
+                                          const size_t n_symbols) // portable version; msac.h maps adapt4/8/16 onto it when the x86-64 asm is unavailable
 {
-    const unsigned count = cdf[n_symbols];
-    const int rate = ((count >> 4) | 4) + (n_symbols > 3);
-    unsigned i;
-    for (i = 0; i < val; i++)
-        cdf[i] += (32768 - cdf[i]) >> rate;
-    for (; i < n_symbols - 1; i++)
-        cdf[i] -= cdf[i] >> rate;
-    cdf[n_symbols] = count + (count < 32);
-}
-
-unsigned dav1d_msac_decode_symbol_adapt(MsacContext *const c,
-                                        uint16_t *const cdf,
-                                        const unsigned n_symbols)
-{
-    const unsigned val = decode_symbol(c, cdf, n_symbols);
-    if(c->allow_update_cdf)
-        update_cdf(cdf, val, n_symbols);
+    const unsigned val = decode_symbol(s, cdf, n_symbols);
+    if (s->allow_update_cdf) { // adapt the CDF toward the decoded symbol
+        const unsigned count = cdf[n_symbols]; // adaptation counter stored right after the CDF entries
+        const int rate = ((count >> 4) | 4) + (n_symbols > 3); // 4..6 as count grows to 32, +1 for more than 3 symbols
+        unsigned i;
+        for (i = 0; i < val; i++) // entries before val move toward 32768
+            cdf[i] += (32768 - cdf[i]) >> rate;
+        for (; i < n_symbols - 1; i++) // the rest decay toward 0; the terminating cdf[n_symbols - 1] is left alone
+            cdf[i] -= cdf[i] >> rate;
+        cdf[n_symbols] = count + (count < 32); // counter saturates at 32
+    }
     return val;
 }
 
-unsigned dav1d_msac_decode_bool_adapt(MsacContext *const c,
+unsigned dav1d_msac_decode_bool_adapt(MsacContext *const s,
                                       uint16_t *const cdf)
 {
-    const unsigned bit = dav1d_msac_decode_bool(c, *cdf);
+    const unsigned bit = dav1d_msac_decode_bool(s, *cdf);
 
-    if(c->allow_update_cdf){
+    if (s->allow_update_cdf) {
         // update_cdf() specialized for boolean CDFs
         const unsigned count = cdf[1];
         const int rate = (count >> 4) | 4;
--- a/src/msac.h
+++ b/src/msac.h
@@ -38,20 +38,37 @@
     const uint8_t *buf_pos;
     const uint8_t *buf_end;
     ec_win dif;
-    uint16_t rng;
+    unsigned rng;
     int cnt;
     int allow_update_cdf;
 } MsacContext;
 
-void dav1d_msac_init(MsacContext *c, const uint8_t *data, size_t sz,
+void dav1d_msac_init(MsacContext *s, const uint8_t *data, size_t sz,
                      int disable_cdf_update_flag);
-unsigned dav1d_msac_decode_symbol_adapt(MsacContext *s, uint16_t *cdf,
-                                        const unsigned n_symbols);
-unsigned dav1d_msac_decode_bool_equi(MsacContext *const s);
+unsigned dav1d_msac_decode_symbol_adapt_c(MsacContext *s, uint16_t *cdf,
+                                          size_t n_symbols);
+unsigned dav1d_msac_decode_bool_equi(MsacContext *s);
 unsigned dav1d_msac_decode_bool(MsacContext *s, unsigned f);
 unsigned dav1d_msac_decode_bool_adapt(MsacContext *s, uint16_t *cdf);
-unsigned dav1d_msac_decode_bools(MsacContext *c, unsigned l);
-int dav1d_msac_decode_subexp(MsacContext *c, int ref, int n, unsigned k);
-int dav1d_msac_decode_uniform(MsacContext *c, unsigned n);
+unsigned dav1d_msac_decode_bools(MsacContext *s, unsigned n);
+int dav1d_msac_decode_subexp(MsacContext *s, int ref, int n, unsigned k);
+int dav1d_msac_decode_uniform(MsacContext *s, unsigned n);
+
+/* Supported n_symbols ranges: adapt4: 1-5, adapt8: 1-8, adapt16: 4-16. Callers must stay within them; the SIMD versions also rely on the CDF arrays being padded as declared in cdf.h. */
+#if ARCH_X86_64 && HAVE_ASM /* SSE2 versions, implemented in src/x86/msac.asm */
+unsigned dav1d_msac_decode_symbol_adapt4_sse2(MsacContext *s, uint16_t *cdf,
+                                              size_t n_symbols);
+unsigned dav1d_msac_decode_symbol_adapt8_sse2(MsacContext *s, uint16_t *cdf,
+                                              size_t n_symbols);
+unsigned dav1d_msac_decode_symbol_adapt16_sse2(MsacContext *s, uint16_t *cdf,
+                                               size_t n_symbols);
+#define dav1d_msac_decode_symbol_adapt4  dav1d_msac_decode_symbol_adapt4_sse2
+#define dav1d_msac_decode_symbol_adapt8  dav1d_msac_decode_symbol_adapt8_sse2
+#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_sse2
+#else /* portable C fallback: all widths share one implementation */
+#define dav1d_msac_decode_symbol_adapt4  dav1d_msac_decode_symbol_adapt_c
+#define dav1d_msac_decode_symbol_adapt8  dav1d_msac_decode_symbol_adapt_c
+#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt_c
+#endif /* ARCH_X86_64 && HAVE_ASM */
 
 #endif /* DAV1D_SRC_MSAC_H */
--- a/src/recon_tmpl.c
+++ b/src/recon_tmpl.c
@@ -107,7 +107,9 @@
             uint16_t *const txtp_cdf = intra ?
                        ts->cdf.m.txtp_intra[set_idx][t_dim->min][y_mode_nofilt] :
                        ts->cdf.m.txtp_inter[set_idx][t_dim->min];
-            idx = dav1d_msac_decode_symbol_adapt(&ts->msac, txtp_cdf, set_cnt);
+            idx = (set_cnt <= 8 ? dav1d_msac_decode_symbol_adapt8 :
+                     dav1d_msac_decode_symbol_adapt16)(&ts->msac, txtp_cdf, set_cnt);
+
             if (dbg)
             printf("Post-txtp[%d->%d][%d->%d][%d][%d->%d]: r=%d\n",
                    set, set_idx, tx, t_dim->min, intra ? (int)y_mode_nofilt : -1,
@@ -122,19 +124,19 @@
     const enum TxClass tx_class = dav1d_tx_type_class[*txtp];
     const int is_1d = tx_class != TX_CLASS_2D;
     switch (tx2dszctx) {
-#define case_sz(sz, bin) \
+#define case_sz(sz, bin, ns) \
     case sz: { \
         uint16_t *const eob_bin_cdf = ts->cdf.coef.eob_bin_##bin[chroma][is_1d]; \
-        eob_bin = dav1d_msac_decode_symbol_adapt(&ts->msac, eob_bin_cdf, 5 + sz); \
+        eob_bin = dav1d_msac_decode_symbol_adapt##ns(&ts->msac, eob_bin_cdf, 5 + sz); \
         break; \
     }
-    case_sz(0,   16);
-    case_sz(1,   32);
-    case_sz(2,   64);
-    case_sz(3,  128);
-    case_sz(4,  256);
-    case_sz(5,  512);
-    case_sz(6, 1024);
+    case_sz(0,   16,  4);
+    case_sz(1,   32,  8);
+    case_sz(2,   64,  8);
+    case_sz(3,  128,  8);
+    case_sz(4,  256, 16);
+    case_sz(5,  512, 16);
+    case_sz(6, 1024, 16);
 #undef case_sz
     }
     if (dbg)
@@ -179,8 +181,8 @@
         uint16_t *const lo_cdf = is_last ?
             ts->cdf.coef.eob_base_tok[t_dim->ctx][chroma][ctx] :
             ts->cdf.coef.base_tok[t_dim->ctx][chroma][ctx];
-        int tok = dav1d_msac_decode_symbol_adapt(&ts->msac, lo_cdf,
-                                                 4 - is_last) + is_last;
+        int tok = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf,
+                                                  4 - is_last) + is_last;
         if (dbg)
         printf("Post-lo_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
                t_dim->ctx, chroma, ctx, i, rc, tok, ts->msac.rng);
@@ -190,7 +192,7 @@
         if (tok == 3) {
             const int br_ctx = get_br_ctx(levels, rc, tx, tx_class);
             do {
-                const int tok_br = dav1d_msac_decode_symbol_adapt(&ts->msac,
+                const int tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
                                        br_cdf[br_ctx], 4);
                 if (dbg)
                 printf("Post-hi_tok[%d][%d][%d][%d=%d=%d->%d]: r=%d\n",
--- /dev/null
+++ b/src/x86/msac.asm
@@ -1,0 +1,287 @@
+; Copyright © 2019, VideoLAN and dav1d authors
+; Copyright © 2019, Two Orioles, LLC
+; All rights reserved.
+;
+; Redistribution and use in source and binary forms, with or without
+; modification, are permitted provided that the following conditions are met:
+;
+; 1. Redistributions of source code must retain the above copyright notice, this
+;    list of conditions and the following disclaimer.
+;
+; 2. Redistributions in binary form must reproduce the above copyright notice,
+;    this list of conditions and the following disclaimer in the documentation
+;    and/or other materials provided with the distribution.
+;
+; THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+; ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+; WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+; DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+; ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+; (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+; ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+; (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+; SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+%include "config.asm"
+%include "ext/x86/x86inc.asm"
+
+%if ARCH_X86_64
+
+SECTION_RODATA 64 ; avoids cacheline splits
+
+dw 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0 ; word i of [pw_0xff00 - 2*n_symbols] = 4*(n_symbols-1-i)
+pw_0xff00: times 8 dw 0xff00 ; must directly follow the table above; masks rng to rng & 0xff00
+pw_32:     times 8 dw 32     ; adapt16: table step between lanes 0-7 and lanes 8-15
+
+struc msac ; NOTE: assumed to mirror MsacContext (src/msac.h) field by field - keep in sync
+    .buf:        resq 1
+    .end:        resq 1
+    .dif:        resq 1
+    .rng:        resd 1
+    .cnt:        resd 1
+    .update_cdf: resd 1
+endstruc
+
+%define m(x) mangle(private_prefix %+ _ %+ x %+ SUFFIX) ; full symbol of another function, for cross-function jumps
+
+SECTION .text
+
+%if WIN64
+DECLARE_REG_TMP 3 ; t0: copy of s, since s arrives in rcx and renorm clobbers ecx
+%define buf rsp+8 ; shadow space (32 bytes above the return address)
+%else
+DECLARE_REG_TMP 0 ; t0 = s
+%define buf rsp-40 ; red zone (128 bytes below rsp)
+%endif
+
+INIT_XMM sse2 ; applies to all three functions below
+cglobal msac_decode_symbol_adapt4, 3, 7, 6, s, cdf, ns ; returns the decoded symbol (val) in eax
+    movd           m2, [sq+msac.rng]
+    movq           m1, [cdfq]        ; cdf[0..3]
+    lea           rax, [pw_0xff00]
+    movq           m3, [sq+msac.dif]
+    mov           r3d, [sq+msac.update_cdf]
+    mov           r4d, nsd          ; n_symbols, also the index of the count
+    neg           nsq               ; -n_symbols, indexes the table below pw_0xff00
+    pshuflw        m2, m2, q0000
+    movd     [buf+12], m2          ; [buf+14] = rng, read back as u when val == 0
+    pand           m2, [rax]         ; rng & 0xff00
+    mova           m0, m1
+    psrlw          m1, 6
+    psllw          m1, 7
+    pmulhuw        m1, m2            ; ((rng >> 8) * (cdf[i] >> 6)) >> 1
+    movq           m2, [rax+nsq*2]   ; 4 * (n_symbols - 1 - i)
+    pshuflw        m3, m3, q3333     ; c = dif >> 48 in every lane
+    paddw          m1, m2            ; v[i]
+    mova     [buf+16], m1          ; spill v[] for the scalar renorm
+    psubusw        m1, m3
+    pxor           m2, m2
+    pcmpeqw        m1, m2 ; c >= v
+    pmovmskb      eax, m1           ; 2 mask bits per 16-bit lane
+    test          r3d, r3d
+    jz .renorm ; !allow_update_cdf
+
+; update_cdf:
+    movzx         r3d, word [cdfq+r4*2] ; count
+    pcmpeqw        m2, m2
+    mov           r2d, r3d
+    shr           r3d, 4
+    cmp           r4d, 4
+    sbb           r3d, -5 ; (count >> 4) + (n_symbols > 3) + 4
+    cmp           r2d, 32
+    adc           r2d, 0  ; count + (count < 32)
+    movd           m3, r3d ; rate
+    pavgw          m2, m1 ; i >= val ? -1 : 32768
+    psubw          m2, m0 ; for (i = 0; i < val; i++)
+    psubw          m0, m1 ;     cdf[i] += (32768 - cdf[i]) >> rate;
+    psraw          m2, m3 ; for (; i < n_symbols - 1; i++)
+    paddw          m0, m2 ;     cdf[i] += ((  -1 - cdf[i]) >> rate) + 1;
+    movq       [cdfq], m0 ; cdf[0..3]
+    mov   [cdfq+r4*2], r2w ; cdf[n_symbols] = new count
+
+.renorm:
+    tzcnt         eax, eax ; 2 * val; input is never 0, so bsf-compatible
+    mov            r4, [sq+msac.dif]
+    movzx         r1d, word [buf+rax+16] ; v
+    movzx         r2d, word [buf+rax+14] ; u
+    shr           eax, 1   ; val (return value)
+.renorm2: ; entered with eax = val, r1d = v, r2d = u, r4 = dif
+    not            r4
+    sub           r2d, r1d ; rng
+    shl            r1, 48
+    add            r4, r1  ; ~dif
+    mov           r1d, [sq+msac.cnt]
+    movifnidn      t0, sq
+    bsr           ecx, r2d
+    xor           ecx, 15  ; d
+    shl           r2d, cl  ; rng << d
+    shl            r4, cl
+    mov [t0+msac.rng], r2d
+    not            r4      ; dif = ((dif - (v << 48) + 1) << d) - 1
+    sub           r1d, ecx ; cnt -= d
+    jge .end ; no refill required
+
+; refill:
+    mov            r2, [t0+msac.buf]
+    mov           rcx, [t0+msac.end]
+    lea            r5, [r2+8]
+    cmp            r5, rcx
+    ja .refill_eob ; unsigned pointer compare: buf + 8 > end
+    mov            r2, [r2] ; 8 bytes, all before end
+    lea           ecx, [r1+23]
+    add           r1d, 16
+    shr           ecx, 3   ; shift_bytes
+    bswap          r2
+    sub            r5, rcx
+    shl           ecx, 3   ; shift_bits
+    shr            r2, cl
+    sub           ecx, r1d ; shift_bits - 16 - cnt
+    mov           r1d, 48
+    shl            r2, cl
+    mov [t0+msac.buf], r5
+    sub           r1d, ecx ; cnt + 64 - shift_bits
+    xor            r4, r2
+.end:
+    mov [t0+msac.cnt], r1d
+    mov [t0+msac.dif], r4
+    RET
+.refill_eob: ; avoid overreading the input buffer
+    mov            r5, rcx
+    mov           ecx, 40
+    sub           ecx, r1d ; c
+.refill_eob_loop:
+    cmp            r2, r5
+    jae .refill_eob_end    ; eob reached (unsigned pointer compare)
+    movzx         r1d, byte [r2]
+    inc            r2
+    shl            r1, cl
+    xor            r4, r1
+    sub           ecx, 8
+    jge .refill_eob_loop
+.refill_eob_end:
+    mov           r1d, 40
+    sub           r1d, ecx ; cnt = 40 - c
+    mov [t0+msac.buf], r2
+    mov [t0+msac.dif], r4
+    mov [t0+msac.cnt], r1d
+    RET
+
+cglobal msac_decode_symbol_adapt8, 3, 7, 6, s, cdf, ns ; same contract as adapt4, 8 lanes
+    movd           m2, [sq+msac.rng]
+    movu           m1, [cdfq]        ; cdf[0..7], may be unaligned
+    lea           rax, [pw_0xff00]
+    movq           m3, [sq+msac.dif]
+    mov           r3d, [sq+msac.update_cdf]
+    mov           r4d, nsd          ; n_symbols, also the index of the count
+    neg           nsq
+    pshuflw        m2, m2, q0000
+    movd     [buf+12], m2          ; [buf+14] = rng, read back as u when val == 0
+    punpcklqdq     m2, m2            ; rng in all 8 lanes
+    mova           m0, m1
+    psrlw          m1, 6
+    pand           m2, [rax]         ; rng & 0xff00
+    psllw          m1, 7
+    pmulhuw        m1, m2            ; ((rng >> 8) * (cdf[i] >> 6)) >> 1
+    movu           m2, [rax+nsq*2]   ; 4 * (n_symbols - 1 - i) for i < n_symbols
+    pshuflw        m3, m3, q3333
+    paddw          m1, m2            ; v[i]
+    punpcklqdq     m3, m3            ; c = dif >> 48 in all 8 lanes
+    mova     [buf+16], m1          ; same v[] layout as adapt4
+    psubusw        m1, m3
+    pxor           m2, m2
+    pcmpeqw        m1, m2            ; c >= v
+    pmovmskb      eax, m1           ; 2 mask bits per 16-bit lane
+    test          r3d, r3d
+    jz m(msac_decode_symbol_adapt4).renorm ; !allow_update_cdf: reuse adapt4's tail
+    movzx         r3d, word [cdfq+r4*2] ; count
+    pcmpeqw        m2, m2
+    mov           r2d, r3d
+    shr           r3d, 4
+    cmp           r4d, 4 ; may be called with n_symbols < 4
+    sbb           r3d, -5 ; rate = (count >> 4) + (n_symbols > 3) + 4
+    cmp           r2d, 32
+    adc           r2d, 0  ; count + (count < 32)
+    movd           m3, r3d
+    pavgw          m2, m1 ; i >= val ? -1 : 32768
+    psubw          m2, m0
+    psubw          m0, m1
+    psraw          m2, m3
+    paddw          m0, m2 ; same cdf update as adapt4, on 8 lanes
+    movu       [cdfq], m0 ; may also rewrite padding past cdf[n_symbols]
+    mov   [cdfq+r4*2], r2w ; cdf[n_symbols] = new count
+    jmp m(msac_decode_symbol_adapt4).renorm
+
+cglobal msac_decode_symbol_adapt16, 3, 7, 6, s, cdf, ns ; same contract as adapt4, 16 lanes
+    movd           m4, [sq+msac.rng]
+    movu           m2, [cdfq]        ; cdf[0..7]
+    lea           rax, [pw_0xff00]
+    movu           m3, [cdfq+16]     ; cdf[8..15]
+    movq           m5, [sq+msac.dif]
+    mov           r3d, [sq+msac.update_cdf]
+    mov           r4d, nsd          ; n_symbols, also the index of the count
+    neg           nsq
+%if WIN64
+    sub           rsp, 48 ; need 36 bytes, shadow space is only 32
+%endif
+    pshuflw        m4, m4, q0000
+    movd      [buf-4], m4          ; [buf-2] = rng, read back as u when val == 0
+    punpcklqdq     m4, m4
+    mova           m0, m2
+    psrlw          m2, 6
+    mova           m1, m3
+    psrlw          m3, 6
+    pand           m4, [rax]         ; rng & 0xff00
+    psllw          m2, 7
+    psllw          m3, 7
+    pmulhuw        m2, m4
+    pmulhuw        m3, m4
+    movu           m4, [rax+nsq*2]   ; 4 * (n_symbols - 1 - i), lanes 0-7
+    pshuflw        m5, m5, q3333
+    paddw          m2, m4            ; v[0..7]
+    psubw          m4, [rax-pw_0xff00+pw_32] ; lanes 8-15: same table, 32 lower
+    punpcklqdq     m5, m5            ; c = dif >> 48 in all lanes
+    paddw          m3, m4            ; v[8..15]
+    mova        [buf], m2
+    mova     [buf+16], m3
+    psubusw        m2, m5
+    psubusw        m3, m5
+    pxor           m4, m4
+    pcmpeqw        m2, m4            ; c >= v, lanes 0-7
+    pcmpeqw        m3, m4            ; c >= v, lanes 8-15
+    packsswb       m5, m2, m3        ; one mask byte per lane
+    pmovmskb      eax, m5           ; one mask bit per lane
+    test          r3d, r3d
+    jz .renorm                       ; !allow_update_cdf
+    movzx         r3d, word [cdfq+r4*2] ; count
+    pcmpeqw        m4, m4
+    mova           m5, m4
+    lea           r2d, [r3+80] ; only support n_symbols >= 4
+    shr           r2d, 4       ; rate = (count >> 4) + 5
+    cmp           r3d, 32
+    adc           r3d, 0       ; count + (count < 32)
+    pavgw          m4, m2       ; i >= val ? -1 : 32768, lanes 0-7
+    pavgw          m5, m3       ; same for lanes 8-15
+    psubw          m4, m0
+    psubw          m0, m2
+    movd           m2, r2d      ; rate
+    psubw          m5, m1
+    psubw          m1, m3
+    psraw          m4, m2
+    psraw          m5, m2
+    paddw          m0, m4
+    paddw          m1, m5
+    movu       [cdfq], m0
+    movu    [cdfq+16], m1
+    mov   [cdfq+r4*2], r3w     ; cdf[n_symbols] = new count
+.renorm:
+    tzcnt         eax, eax     ; val (one bit per lane, no halving needed)
+    mov            r4, [sq+msac.dif]
+    movzx         r1d, word [buf+rax*2]   ; v
+    movzx         r2d, word [buf+rax*2-2] ; u
+%if WIN64
+    add           rsp, 48
+%endif
+    jmp m(msac_decode_symbol_adapt4).renorm2
+
+%endif
--- a/tests/checkasm/checkasm.c
+++ b/tests/checkasm/checkasm.c
@@ -62,6 +62,7 @@
     const char *name;
     void (*func)(void);
 } tests[] = {
+    { "msac", checkasm_check_msac },
 #if CONFIG_8BPC
     { "cdef_8bpc", checkasm_check_cdef_8bpc },
     { "ipred_8bpc", checkasm_check_ipred_8bpc },
--- a/tests/checkasm/checkasm.h
+++ b/tests/checkasm/checkasm.h
@@ -57,6 +57,7 @@
 name##_8bpc(void); \
 name##_16bpc(void)
 
+void checkasm_check_msac(void);
 decl_check_bitfns(void checkasm_check_cdef);
 decl_check_bitfns(void checkasm_check_ipred);
 decl_check_bitfns(void checkasm_check_itx);
--- /dev/null
+++ b/tests/checkasm/msac.c
@@ -1,0 +1,115 @@
+/*
+ * Copyright © 2019, VideoLAN and dav1d authors
+ * Copyright © 2019, Two Orioles, LLC
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#include "tests/checkasm/checkasm.h"
+
+#include "src/cpu.h"
+#include "src/msac.h"
+
+#include <string.h>
+
+/* Shared signature of the C and SIMD decoders. Only checkasm needs a
+ * pointer type: the decoder itself calls them directly. */
+typedef unsigned (*decode_symbol_adapt_fn)(MsacContext *s, uint16_t *cdf,
+                                           size_t n_symbols);
+
+/* One entry per CDF vector width (4, 8 or 16 lanes). */
+typedef struct {
+    decode_symbol_adapt_fn symbol_adapt4, symbol_adapt8, symbol_adapt16;
+} MsacDSPContext;
+
+static void randomize_cdf(uint16_t *const cdf, int n) { /* n = n_symbols */
+    for (int pad = 16; pad > n; pad--) /* junk past the count; asm may load it */
+        cdf[pad] = rnd();
+    cdf[n] = cdf[n - 1] = 0; /* zero count; terminating 0 entry */
+    for (int k = n - 1; k > 0; k--) /* cdf[k-1] > cdf[k], at most 32768 - k */
+        cdf[k - 1] = cdf[k] + rnd() % (32768 - cdf[k] - k) + 1;
+}
+
+/* Nonzero if any field differs (memcmp() would also compare padding). */
+static int msac_cmp(const MsacContext *const a, const MsacContext *const b) {
+    return !(a->buf_pos == b->buf_pos && a->buf_end == b->buf_end &&
+             a->dif == b->dif && a->rng == b->rng && a->cnt == b->cnt &&
+             a->allow_update_cdf == b->allow_update_cdf);
+}
+
+/* Check width n for n_symbols in [n_min, n_max]; uses c, buf, cdf, s_c, s_a. */
+#define CHECK_SYMBOL_ADAPT(n, n_min, n_max) do {                            \
+    if (!check_func(c->symbol_adapt##n, "msac_decode_symbol_adapt%d", n))  \
+        break; /* not selected by check_func() */                           \
+    for (int upd = 0; upd < 2; upd++) {                                     \
+        for (int nsym = n_min; nsym <= n_max; nsym++) {                     \
+            dav1d_msac_init(&s_c, buf, sizeof(buf), !upd);                  \
+            s_a = s_c;                                                      \
+            randomize_cdf(cdf[0], nsym);                                    \
+            memcpy(cdf[1], cdf[0], sizeof(*cdf));                           \
+            for (int iter = 0; iter < 64; iter++) {                         \
+                const unsigned ref_val = call_ref(&s_c, cdf[0], nsym);      \
+                const unsigned new_val = call_new(&s_a, cdf[1], nsym);      \
+                if (ref_val != new_val || msac_cmp(&s_c, &s_a) ||           \
+                    memcmp(cdf[0], cdf[1], sizeof(**cdf) * (nsym + 1))) {   \
+                    fail();                                                 \
+                }                                                           \
+            }                                                               \
+            if (upd && nsym == n) /* benchmark the nominal width only */    \
+                bench_new(&s_a, cdf[0], n);                                 \
+        }                                                                   \
+    }                                                                       \
+} while (0)
+
+static void check_decode_symbol_adapt(MsacDSPContext *const c) {
+    /* Use an aligned CDF buffer for more consistent benchmark
+     * results, and a misaligned one for checking correctness. */
+    ALIGN_STK_16(uint16_t, cdf, 2, [17]); /* 16 lanes + count slot */
+    MsacContext s_c, s_a; /* C reference / asm decoder state */
+    uint8_t buf[1024]; /* random bitstream, reread by every run */
+    for (size_t i = 0; i < sizeof(buf); i++)
+        buf[i] = rnd();
+
+    declare_func(unsigned, MsacContext *s, uint16_t *cdf, size_t n_symbols);
+    CHECK_SYMBOL_ADAPT( 4, 1,  5); /* (width, min, max n_symbols) */
+    CHECK_SYMBOL_ADAPT( 8, 1,  8);
+    CHECK_SYMBOL_ADAPT(16, 4, 16); /* adapt16 needs n_symbols >= 4 */
+    report("decode_symbol_adapt");
+}
+
+void checkasm_check_msac(void) {
+    /* Start from the C reference for every width; asm overrides below. */
+    MsacDSPContext c = {
+        .symbol_adapt4  = dav1d_msac_decode_symbol_adapt_c,
+        .symbol_adapt8  = dav1d_msac_decode_symbol_adapt_c,
+        .symbol_adapt16 = dav1d_msac_decode_symbol_adapt_c,
+    };
+#if ARCH_X86_64 && HAVE_ASM
+    if (dav1d_get_cpu_flags() & DAV1D_X86_CPU_FLAG_SSE2) {
+        c.symbol_adapt4  = dav1d_msac_decode_symbol_adapt4_sse2;
+        c.symbol_adapt8  = dav1d_msac_decode_symbol_adapt8_sse2;
+        c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_sse2;
+    }
+#endif
+    check_decode_symbol_adapt(&c);
+}
--- a/tests/meson.build
+++ b/tests/meson.build
@@ -34,7 +34,10 @@
 libdav1d_nasm_objs_if_needed = []
 
 if is_asm_enabled
-    checkasm_sources = files('checkasm/checkasm.c')
+    checkasm_sources = files(
+        'checkasm/checkasm.c',
+        'checkasm/msac.c',
+    )
 
     checkasm_tmpl_sources = files(
         'checkasm/cdef.c',