shithub: dav1d

Download patch

ref: c3980e394d32ed832dfd65decde5f210c03b2f27
parent: 2e6c8a92d25234cb27651a76760fd2b50591bc51
author: Ronald S. Bultje <[email protected]>
date: Wed Dec 5 13:21:05 EST 2018

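12 bits/component support

Replace the separate 10 bpc build with a single 16 bpc template that handles
both 10- and 12-bit content, passing the maximum pixel value (bitdepth_max)
at runtime. The meson 'bitdepths' option value changes from '10' to '16'.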

--- a/include/common/bitdepth.h
+++ b/include/common/bitdepth.h
@@ -34,6 +34,9 @@
 #if !defined(BITDEPTH)
 typedef void pixel;
 typedef void coef;
+#define HIGHBD_DECL_SUFFIX /* nothing */
+#define HIGHBD_CALL_SUFFIX /* nothing */
+#define HIGHBD_TAIL_SUFFIX /* nothing */
 #elif BITDEPTH == 8
 typedef uint8_t pixel;
 typedef int16_t coef;
@@ -41,28 +44,42 @@
 #define pixel_set memset
 #define iclip_pixel iclip_u8
 #define PIX_HEX_FMT "%02x"
-#define bytefn(x) x##_8bpc
 #define bitfn(x) x##_8bpc
-#define PXSTRIDE(x) x
-#elif BITDEPTH == 10 || BITDEPTH == 12
+#define PXSTRIDE(x) (x)
+#define highbd_only(x)
+#define HIGHBD_DECL_SUFFIX /* nothing */
+#define HIGHBD_CALL_SUFFIX /* nothing */
+#define HIGHBD_TAIL_SUFFIX /* nothing */
+#define bitdepth_from_max(x) 8
+#elif BITDEPTH == 16
 typedef uint16_t pixel;
 typedef int32_t coef;
 #define pixel_copy(a, b, c) memcpy(a, b, (c) << 1)
-#define iclip_pixel(x) iclip(x, 0, ((1 << BITDEPTH) - 1))
 static inline void pixel_set(pixel *const dst, const int val, const int num) {
     for (int n = 0; n < num; n++)
         dst[n] = val;
 }
 #define PIX_HEX_FMT "%03x"
-#define bytefn(x) x##_16bpc
-#if BITDEPTH == 10
-#define bitfn(x) x##_10bpc
-#else
-#define bitfn(x) x##_12bpc
-#endif
+#define iclip_pixel(x) iclip(x, 0, bitdepth_max)
+/* 16 bpc builds serve both 10- and 12-bit streams: DSP functions take the
+ * maximum pixel value ((1 << bpc) - 1) as a trailing bitdepth_max argument.
+ * HIGHBD_CALL_SUFFIX reads it from f->bitdepth_max, so it needs `f` in scope. */
+#define HIGHBD_DECL_SUFFIX , const int bitdepth_max
+#define HIGHBD_CALL_SUFFIX , f->bitdepth_max
+#define HIGHBD_TAIL_SUFFIX , bitdepth_max
+#define bitdepth_from_max(bitdepth_max) (32 - clz(bitdepth_max))
+#define bitfn(x) x##_16bpc
-#define PXSTRIDE(x) (x >> 1)
+#define PXSTRIDE(x) ((x) >> 1)
+#define highbd_only(x) x
 #else
 #error invalid value for bitdepth
 #endif
+#define bytefn(x) bitfn(x)
+
+/* Declares both instantiations of a templated function, e.g.
+ * bitfn_decls(void foo, int a) expands to void foo_8bpc(int a); void foo_16bpc(int a) */
+#define bitfn_decls(name, ...) \
+name##_8bpc(__VA_ARGS__); \
+name##_16bpc(__VA_ARGS__)
 
 #endif /* __DAV1D_COMMON_BITDEPTH_H__ */
--- a/meson.build
+++ b/meson.build
@@ -55,7 +55,7 @@
 
 # Bitdepth option
 dav1d_bitdepths = get_option('bitdepths')
-foreach bitdepth : ['8', '10']
+foreach bitdepth : ['8', '16']
     cdata.set10('CONFIG_@0@BPC'.format(bitdepth), dav1d_bitdepths.contains(bitdepth))
 endforeach
 
--- a/meson_options.txt
+++ b/meson_options.txt
@@ -2,7 +2,7 @@
 
 option('bitdepths',
     type: 'array',
-    choices: ['8', '10'],
+    choices: ['8', '16'],
     description: 'Enable only specified bitdepths')
 
 option('build_asm',
--- a/src/cdef.h
+++ b/src/cdef.h
@@ -53,11 +53,11 @@
 #define decl_cdef_fn(name) \
 void (name)(pixel *dst, ptrdiff_t stride, const_left_pixel_row_2px left, \
             /*const*/ pixel *const top[2], int pri_strength, int sec_strength, \
-            int dir, int damping, enum CdefEdgeFlags edges)
+            int dir, int damping, enum CdefEdgeFlags edges HIGHBD_DECL_SUFFIX)
 typedef decl_cdef_fn(*cdef_fn);
 
 #define decl_cdef_dir_fn(name) \
-int (name)(const pixel *dst, ptrdiff_t dst_stride, unsigned *var)
+int (name)(const pixel *dst, ptrdiff_t dst_stride, unsigned *var HIGHBD_DECL_SUFFIX)
 typedef decl_cdef_dir_fn(*cdef_dir_fn);
 
 typedef struct Dav1dCdefDSPContext {
@@ -65,10 +65,7 @@
     cdef_fn fb[3 /* 444/luma, 422, 420 */];
 } Dav1dCdefDSPContext;
 
-void dav1d_cdef_dsp_init_8bpc(Dav1dCdefDSPContext *c);
-void dav1d_cdef_dsp_init_10bpc(Dav1dCdefDSPContext *c);
-
-void dav1d_cdef_dsp_init_x86_8bpc(Dav1dCdefDSPContext *c);
-void dav1d_cdef_dsp_init_x86_10bpc(Dav1dCdefDSPContext *c);
+bitfn_decls(void dav1d_cdef_dsp_init, Dav1dCdefDSPContext *c);
+bitfn_decls(void dav1d_cdef_dsp_init_x86, Dav1dCdefDSPContext *c);
 
 #endif /* __DAV1D_SRC_CDEF_H__ */
--- a/src/cdef_apply_tmpl.c
+++ b/src/cdef_apply_tmpl.c
@@ -83,12 +83,13 @@
                              const Av1Filter *const lflvl,
                              const int by_start, const int by_end)
 {
+    const int bitdepth_min_8 = BITDEPTH == 8 ? 0 : f->cur.p.bpc - 8;
     const Dav1dDSPContext *const dsp = f->dsp;
     enum CdefEdgeFlags edges = HAVE_BOTTOM | (by_start > 0 ? HAVE_TOP : 0);
     pixel *ptrs[3] = { p[0], p[1], p[2] };
     const int sbsz = 16;
     const int sb64w = f->sb128w << 1;
-    const int damping = f->frame_hdr->cdef.damping + BITDEPTH - 8;
+    const int damping = f->frame_hdr->cdef.damping + bitdepth_min_8;
     const enum Dav1dPixelLayout layout = f->cur.p.layout;
     const int uv_idx = DAV1D_PIXEL_LAYOUT_I444 - layout;
     const int has_chroma = layout != DAV1D_PIXEL_LAYOUT_I400;
@@ -156,17 +157,17 @@
                 }
 
                 // the actual filter
-                const int y_pri_lvl = (y_lvl >> 2) << (BITDEPTH - 8);
+                const int y_pri_lvl = (y_lvl >> 2) << bitdepth_min_8;
                 int y_sec_lvl = y_lvl & 3;
                 y_sec_lvl += y_sec_lvl == 3;
-                y_sec_lvl <<= BITDEPTH - 8;
-                const int uv_pri_lvl = (uv_lvl >> 2) << (BITDEPTH - 8);
+                y_sec_lvl <<= bitdepth_min_8;
+                const int uv_pri_lvl = (uv_lvl >> 2) << bitdepth_min_8;
                 int uv_sec_lvl = uv_lvl & 3;
                 uv_sec_lvl += uv_sec_lvl == 3;
-                uv_sec_lvl <<= BITDEPTH - 8;
+                uv_sec_lvl <<= bitdepth_min_8;
                 unsigned variance;
                 const int dir = dsp->cdef.dir(bptrs[0], f->cur.stride[0],
-                                              &variance);
+                                              &variance HIGHBD_CALL_SUFFIX);
                 if (y_lvl) {
                     dsp->cdef.fb[0](bptrs[0], f->cur.stride[0], lr_bak[bit][0],
                                     (pixel *const [2]) {
@@ -175,7 +176,7 @@
                                     },
                                     adjust_strength(y_pri_lvl, variance),
                                     y_sec_lvl, y_pri_lvl ? dir : 0,
-                                    damping, edges);
+                                    damping, edges HIGHBD_CALL_SUFFIX);
                 }
                 if (uv_lvl && has_chroma) {
                     const int uvdir =
@@ -190,7 +191,7 @@
                                              },
                                              uv_pri_lvl, uv_sec_lvl,
                                              uv_pri_lvl ? uvdir : 0,
-                                             damping - 1, edges);
+                                             damping - 1, edges HIGHBD_CALL_SUFFIX);
                     }
                 }
 
--- a/src/cdef_tmpl.c
+++ b/src/cdef_tmpl.c
@@ -97,7 +97,8 @@
                     const pixel (*left)[2], /*const*/ pixel *const top[2],
                     const int w, const int h, const int pri_strength,
                     const int sec_strength, const int dir,
-                    const int damping, const enum CdefEdgeFlags edges)
+                    const int damping, const enum CdefEdgeFlags edges
+                    HIGHBD_DECL_SUFFIX)
 {
     static const int8_t cdef_directions[8 /* dir */][2 /* pass */] = {
         { -1 * 12 + 1, -2 * 12 + 2 },
@@ -115,7 +116,8 @@
     assert((w == 4 || w == 8) && (h == 4 || h == 8));
     uint16_t tmp_buf[144];  // 12*12 is the maximum value of tmp_stride * (h + 4)
     uint16_t *tmp = tmp_buf + 2 * tmp_stride + 2;
-    const uint8_t *const pri_taps = cdef_pri_taps[(pri_strength >> (BITDEPTH - 8)) & 1];
+    const int bitdepth_min_8 = bitdepth_from_max(bitdepth_max) - 8;
+    const uint8_t *const pri_taps = cdef_pri_taps[(pri_strength >> bitdepth_min_8) & 1];
 
     padding(tmp, tmp_stride, dst, dst_stride, left, top, w, h, edges);
 
@@ -170,10 +172,11 @@
                                             const int sec_strength, \
                                             const int dir, \
                                             const int damping, \
-                                            const enum CdefEdgeFlags edges) \
+                                            const enum CdefEdgeFlags edges \
+                                            HIGHBD_DECL_SUFFIX) \
 { \
     cdef_filter_block_c(dst, stride, left, top, w, h, pri_strength, sec_strength, \
-                        dir, damping, edges); \
+                        dir, damping, edges HIGHBD_TAIL_SUFFIX); \
 }
 
 cdef_fn(4, 4);
@@ -181,8 +184,9 @@
 cdef_fn(8, 8);
 
 static int cdef_find_dir_c(const pixel *img, const ptrdiff_t stride,
-                           unsigned *const var)
+                           unsigned *const var HIGHBD_DECL_SUFFIX)
 {
+    const int bitdepth_min_8 = bitdepth_from_max(bitdepth_max) - 8;
     int partial_sum_hv[2][8] = { { 0 } };
     int partial_sum_diag[2][15] = { { 0 } };
     int partial_sum_alt[4][11] = { { 0 } };
@@ -189,7 +193,7 @@
 
     for (int y = 0; y < 8; y++) {
         for (int x = 0; x < 8; x++) {
-            const int px = (img[x] >> (BITDEPTH - 8)) - 128;
+            const int px = (img[x] >> bitdepth_min_8) - 128;
 
             partial_sum_diag[0][     y       +  x      ] += px;
             partial_sum_alt [0][     y       + (x >> 1)] += px;
--- a/src/decode.c
+++ b/src/decode.c
@@ -3013,7 +3013,6 @@
 
         switch (bpc) {
 #define assign_bitdepth_case(bd) \
-        case bd: \
             dav1d_cdef_dsp_init_##bd##bpc(&dsp->cdef); \
             dav1d_intra_pred_dsp_init_##bd##bpc(&dsp->ipred); \
             dav1d_itx_dsp_init_##bd##bpc(&dsp->itx); \
@@ -3022,10 +3021,13 @@
             dav1d_mc_dsp_init_##bd##bpc(&dsp->mc); \
             break
 #if CONFIG_8BPC
-        assign_bitdepth_case(8);
+        case 8:
+            assign_bitdepth_case(8);
 #endif
-#if CONFIG_10BPC
-        assign_bitdepth_case(10);
+#if CONFIG_16BPC
+        case 10:
+        case 12:
+            assign_bitdepth_case(16);
 #endif
 #undef assign_bitdepth_case
         default:
@@ -3047,7 +3049,7 @@
         assign_bitdepth_case(8);
 #endif
     } else {
-#if CONFIG_10BPC
+#if CONFIG_16BPC
         assign_bitdepth_case(16);
 #endif
     }
@@ -3168,6 +3170,7 @@
     f->sb_step = 16 << f->seq_hdr->sb128;
     f->sbh = (f->bh + f->sb_step - 1) >> f->sb_shift;
     f->b4_stride = (f->bw + 31) & ~31;
+    f->bitdepth_max = (1 << f->cur.p.bpc) - 1;
 
     // ref_mvs
     if ((f->frame_hdr->frame_type & 1) || f->frame_hdr->allow_intrabc) {
--- a/src/dequant_tables.c
+++ b/src/dequant_tables.c
@@ -160,5 +160,70 @@
         { 3586, 5916, }, { 3702, 6032, }, { 3823, 6148, }, { 3953, 6268, },
         { 4089, 6388, }, { 4236, 6512, }, { 4394, 6640, }, { 4559, 6768, },
         { 4737, 6900, }, { 4929, 7036, }, { 5130, 7172, }, { 5347, 7312, },
+    }, {
+        {     4,     4 }, {    12,    13 }, {    18,    19 }, {    25,    27 },
+        {    33,    35 }, {    41,    44 }, {    50,    54 }, {    60,    64 },
+        {    70,    75 }, {    80,    87 }, {    91,    99 }, {   103,   112 },
+        {   115,   126 }, {   127,   139 }, {   140,   154 }, {   153,   168 },
+        {   166,   183 }, {   180,   199 }, {   194,   214 }, {   208,   230 },
+        {   222,   247 }, {   237,   263 }, {   251,   280 }, {   266,   297 },
+        {   281,   314 }, {   296,   331 }, {   312,   349 }, {   327,   366 },
+        {   343,   384 }, {   358,   402 }, {   374,   420 }, {   390,   438 },
+        {   405,   456 }, {   421,   475 }, {   437,   493 }, {   453,   511 },
+        {   469,   530 }, {   484,   548 }, {   500,   567 }, {   516,   586 },
+        {   532,   604 }, {   548,   623 }, {   564,   642 }, {   580,   660 },
+        {   596,   679 }, {   611,   698 }, {   627,   716 }, {   643,   735 },
+        {   659,   753 }, {   674,   772 }, {   690,   791 }, {   706,   809 },
+        {   721,   828 }, {   737,   846 }, {   752,   865 }, {   768,   884 },
+        {   783,   902 }, {   798,   920 }, {   814,   939 }, {   829,   957 },
+        {   844,   976 }, {   859,   994 }, {   874,  1012 }, {   889,  1030 },
+        {   904,  1049 }, {   919,  1067 }, {   934,  1085 }, {   949,  1103 },
+        {   964,  1121 }, {   978,  1139 }, {   993,  1157 }, {  1008,  1175 },
+        {  1022,  1193 }, {  1037,  1211 }, {  1051,  1229 }, {  1065,  1246 },
+        {  1080,  1264 }, {  1094,  1282 }, {  1108,  1299 }, {  1122,  1317 },
+        {  1136,  1335 }, {  1151,  1352 }, {  1165,  1370 }, {  1179,  1387 },
+        {  1192,  1405 }, {  1206,  1422 }, {  1220,  1440 }, {  1234,  1457 },
+        {  1248,  1474 }, {  1261,  1491 }, {  1275,  1509 }, {  1288,  1526 },
+        {  1302,  1543 }, {  1315,  1560 }, {  1329,  1577 }, {  1342,  1595 },
+        {  1368,  1627 }, {  1393,  1660 }, {  1419,  1693 }, {  1444,  1725 },
+        {  1469,  1758 }, {  1494,  1791 }, {  1519,  1824 }, {  1544,  1856 },
+        {  1569,  1889 }, {  1594,  1922 }, {  1618,  1954 }, {  1643,  1987 },
+        {  1668,  2020 }, {  1692,  2052 }, {  1717,  2085 }, {  1741,  2118 },
+        {  1765,  2150 }, {  1789,  2183 }, {  1814,  2216 }, {  1838,  2248 },
+        {  1862,  2281 }, {  1885,  2313 }, {  1909,  2346 }, {  1933,  2378 },
+        {  1957,  2411 }, {  1992,  2459 }, {  2027,  2508 }, {  2061,  2556 },
+        {  2096,  2605 }, {  2130,  2653 }, {  2165,  2701 }, {  2199,  2750 },
+        {  2233,  2798 }, {  2267,  2847 }, {  2300,  2895 }, {  2334,  2943 },
+        {  2367,  2992 }, {  2400,  3040 }, {  2434,  3088 }, {  2467,  3137 },
+        {  2499,  3185 }, {  2532,  3234 }, {  2575,  3298 }, {  2618,  3362 },
+        {  2661,  3426 }, {  2704,  3491 }, {  2746,  3555 }, {  2788,  3619 },
+        {  2830,  3684 }, {  2872,  3748 }, {  2913,  3812 }, {  2954,  3876 },
+        {  2995,  3941 }, {  3036,  4005 }, {  3076,  4069 }, {  3127,  4149 },
+        {  3177,  4230 }, {  3226,  4310 }, {  3275,  4390 }, {  3324,  4470 },
+        {  3373,  4550 }, {  3421,  4631 }, {  3469,  4711 }, {  3517,  4791 },
+        {  3565,  4871 }, {  3621,  4967 }, {  3677,  5064 }, {  3733,  5160 },
+        {  3788,  5256 }, {  3843,  5352 }, {  3897,  5448 }, {  3951,  5544 },
+        {  4005,  5641 }, {  4058,  5737 }, {  4119,  5849 }, {  4181,  5961 },
+        {  4241,  6073 }, {  4301,  6185 }, {  4361,  6297 }, {  4420,  6410 },
+        {  4479,  6522 }, {  4546,  6650 }, {  4612,  6778 }, {  4677,  6906 },
+        {  4742,  7034 }, {  4807,  7162 }, {  4871,  7290 }, {  4942,  7435 },
+        {  5013,  7579 }, {  5083,  7723 }, {  5153,  7867 }, {  5222,  8011 },
+        {  5291,  8155 }, {  5367,  8315 }, {  5442,  8475 }, {  5517,  8635 },
+        {  5591,  8795 }, {  5665,  8956 }, {  5745,  9132 }, {  5825,  9308 },
+        {  5905,  9484 }, {  5984,  9660 }, {  6063,  9836 }, {  6149, 10028 },
+        {  6234, 10220 }, {  6319, 10412 }, {  6404, 10604 }, {  6495, 10812 },
+        {  6587, 11020 }, {  6678, 11228 }, {  6769, 11437 }, {  6867, 11661 },
+        {  6966, 11885 }, {  7064, 12109 }, {  7163, 12333 }, {  7269, 12573 },
+        {  7376, 12813 }, {  7483, 13053 }, {  7599, 13309 }, {  7715, 13565 },
+        {  7832, 13821 }, {  7958, 14093 }, {  8085, 14365 }, {  8214, 14637 },
+        {  8352, 14925 }, {  8492, 15213 }, {  8635, 15502 }, {  8788, 15806 },
+        {  8945, 16110 }, {  9104, 16414 }, {  9275, 16734 }, {  9450, 17054 },
+        {  9639, 17390 }, {  9832, 17726 }, { 10031, 18062 }, { 10245, 18414 },
+        { 10465, 18766 }, { 10702, 19134 }, { 10946, 19502 }, { 11210, 19886 },
+        { 11482, 20270 }, { 11776, 20670 }, { 12081, 21070 }, { 12409, 21486 },
+        { 12750, 21902 }, { 13118, 22334 }, { 13501, 22766 }, { 13913, 23214 },
+        { 14343, 23662 }, { 14807, 24126 }, { 15290, 24590 }, { 15812, 25070 },
+        { 16356, 25551 }, { 16943, 26047 }, { 17575, 26559 }, { 18237, 27071 },
+        { 18949, 27599 }, { 19718, 28143 }, { 20521, 28687 }, { 21387, 29247 },
     }
 };
--- a/src/film_grain.h
+++ b/src/film_grain.h
@@ -30,10 +30,8 @@
 
 #include "dav1d/dav1d.h"
+#include "common/bitdepth.h" /* for bitfn_decls() */
 
-void dav1d_apply_grain_8bpc(Dav1dPicture *const out,
-                            const Dav1dPicture *const in);
-
-void dav1d_apply_grain_10bpc(Dav1dPicture *const out,
-                             const Dav1dPicture *const in);
+bitfn_decls(void dav1d_apply_grain, Dav1dPicture *const out,
+                                    const Dav1dPicture *const in);
 
 #endif /* __DAV1D_SRC_FILM_GRAIN_H__ */
--- a/src/film_grain_tmpl.c
+++ b/src/film_grain_tmpl.c
@@ -51,7 +51,11 @@
     SUB_GRAIN_HEIGHT = 38,
     SUB_GRAIN_OFFSET = 6,
     BLOCK_SIZE = 32,
-    SCALING_SIZE = 1 << BITDEPTH,
+#if BITDEPTH == 8
+    SCALING_SIZE = 256
+#else
+    SCALING_SIZE = 4096
+#endif
 };
 
 static inline int get_random_number(const int bits, unsigned *state) {
@@ -66,18 +70,14 @@
     return (x + ((1 << shift) >> 1)) >> shift;
 }
 
-enum {
-    GRAIN_CENTER = 128 << (BITDEPTH - 8),
-    GRAIN_MIN = -GRAIN_CENTER,
-    GRAIN_MAX = (256 << (BITDEPTH - 8)) - 1 - GRAIN_CENTER,
-};
-
 static void generate_grain_y(const Dav1dPicture *const in,
                              entry buf[GRAIN_HEIGHT][GRAIN_WIDTH])
 {
     const Dav1dFilmGrainData *data = &in->frame_hdr->film_grain.data;
     unsigned seed = data->seed;
-    const int shift = 12 - BITDEPTH + data->grain_scale_shift;
+    const int shift = 12 - in->p.bpc + data->grain_scale_shift;
+    const int grain_ctr = 128 << (in->p.bpc - 8);
+    const int grain_min = -grain_ctr, grain_max = grain_ctr - 1;
 
     for (int y = 0; y < GRAIN_HEIGHT; y++) {
         for (int x = 0; x < GRAIN_WIDTH; x++) {
@@ -102,7 +102,7 @@
             }
 
             int grain = buf[y][x] + round2(sum, data->ar_coeff_shift);
-            buf[y][x] = iclip(grain, GRAIN_MIN, GRAIN_MAX);
+            buf[y][x] = iclip(grain, grain_min, grain_max);
         }
     }
 }
@@ -113,7 +113,9 @@
 {
     const Dav1dFilmGrainData *data = &in->frame_hdr->film_grain.data;
     unsigned seed = data->seed ^ (uv ? 0x49d8 : 0xb524);
-    const int shift = 12 - BITDEPTH + data->grain_scale_shift;
+    const int shift = 12 - in->p.bpc + data->grain_scale_shift;
+    const int grain_ctr = 128 << (in->p.bpc - 8);
+    const int grain_min = -grain_ctr, grain_max = grain_ctr - 1;
 
     const int subx = in->p.layout != DAV1D_PIXEL_LAYOUT_I444;
     const int suby = in->p.layout == DAV1D_PIXEL_LAYOUT_I420;
@@ -160,15 +162,17 @@
             }
 
             const int grain = buf[y][x] + round2(sum, data->ar_coeff_shift);
-            buf[y][x] = iclip(grain, GRAIN_MIN, GRAIN_MAX);
+            buf[y][x] = iclip(grain, grain_min, grain_max);
         }
     }
 }
 
-static void generate_scaling(const uint8_t points[][2], int num,
+static void generate_scaling(const int bitdepth,
+                             const uint8_t points[][2], int num,
                              uint8_t scaling[SCALING_SIZE])
 {
-    const int shift_x = BITDEPTH - 8;
+    const int shift_x = bitdepth - 8;
+    const int scaling_size = 1 << bitdepth;
 
     // Fill up the preceding entries with the initial value
     for (int i = 0; i < points[0][0] << shift_x; i++)
@@ -190,7 +194,7 @@
     }
 
     // Fill up the remaining entries with the final value
-    for (int i = points[num - 1][0] << shift_x; i < SCALING_SIZE; i++)
+    for (int i = points[num - 1][0] << shift_x; i < scaling_size; i++)
         scaling[i] = points[num - 1][1];
 }
 
@@ -213,14 +217,17 @@
 {
     const Dav1dFilmGrainData *const data = &out->frame_hdr->film_grain.data;
     const int rows = 1 + (data->overlap_flag && row_num > 0);
+    const int bitdepth_min_8 = in->p.bpc - 8;
+    const int grain_ctr = 128 << bitdepth_min_8;
+    const int grain_min = -grain_ctr, grain_max = grain_ctr - 1;
 
     int min_value, max_value;
     if (data->clip_to_restricted_range) {
-        min_value = 16 << (BITDEPTH - 8);
-        max_value = 235 << (BITDEPTH - 8);
+        min_value = 16 << bitdepth_min_8;
+        max_value = 235 << bitdepth_min_8;
     } else {
         min_value = 0;
-        max_value = (1 << BITDEPTH) - 1;
+        max_value = (1U << in->p.bpc) - 1;
     }
 
     // seed[0] contains the current row, seed[1] contains the previous
@@ -278,7 +285,7 @@
                 int grain = sample_lut(grain_lut, offsets, 0, 0, 0, 0, x, y);
                 int old   = sample_lut(grain_lut, offsets, 0, 0, 1, 0, x, y);
                 grain = round2(old * w[x][0] + grain * w[x][1], 5);
-                grain = iclip(grain, GRAIN_MIN, GRAIN_MAX);
+                grain = iclip(grain, grain_min, grain_max);
                 add_noise_y(x, y, grain);
             }
         }
@@ -289,7 +296,7 @@
                 int grain = sample_lut(grain_lut, offsets, 0, 0, 0, 0, x, y);
                 int old   = sample_lut(grain_lut, offsets, 0, 0, 0, 1, x, y);
                 grain = round2(old * w[y][0] + grain * w[y][1], 5);
-                grain = iclip(grain, GRAIN_MIN, GRAIN_MAX);
+                grain = iclip(grain, grain_min, grain_max);
                 add_noise_y(x, y, grain);
             }
 
@@ -299,17 +306,17 @@
                 int top = sample_lut(grain_lut, offsets, 0, 0, 0, 1, x, y);
                 int old = sample_lut(grain_lut, offsets, 0, 0, 1, 1, x, y);
                 top = round2(old * w[x][0] + top * w[x][1], 5);
-                top = iclip(top, GRAIN_MIN, GRAIN_MAX);
+                top = iclip(top, grain_min, grain_max);
 
                 // Blend the current pixel with the left block
                 int grain = sample_lut(grain_lut, offsets, 0, 0, 0, 0, x, y);
                 old = sample_lut(grain_lut, offsets, 0, 0, 1, 0, x, y);
                 grain = round2(old * w[x][0] + grain * w[x][1], 5);
-                grain = iclip(grain, GRAIN_MIN, GRAIN_MAX);
+                grain = iclip(grain, grain_min, grain_max);
 
                 // Mix the row rows together and apply grain
                 grain = round2(top * w[y][0] + grain * w[y][1], 5);
-                grain = iclip(grain, GRAIN_MIN, GRAIN_MAX);
+                grain = iclip(grain, grain_min, grain_max);
                 add_noise_y(x, y, grain);
             }
         }
@@ -322,18 +329,22 @@
 {
     const Dav1dFilmGrainData *const data = &out->frame_hdr->film_grain.data;
     const int rows = 1 + (data->overlap_flag && row_num > 0);
+    const int bitdepth_max = (1 << in->p.bpc) - 1;
+    const int bitdepth_min_8 = in->p.bpc - 8;
+    const int grain_ctr = 128 << bitdepth_min_8;
+    const int grain_min = -grain_ctr, grain_max = grain_ctr - 1;
 
     int min_value, max_value;
     if (data->clip_to_restricted_range) {
-        min_value = 16 << (BITDEPTH - 8);
+        min_value = 16 << bitdepth_min_8;
         if (out->seq_hdr->mtrx == DAV1D_MC_IDENTITY) {
-            max_value = 235 << (BITDEPTH - 8);
+            max_value = 235 << bitdepth_min_8;
         } else {
-            max_value = 240 << (BITDEPTH - 8);
+            max_value = 240 << bitdepth_min_8;
         }
     } else {
         min_value = 0;
-        max_value = (1 << BITDEPTH) - 1;
+        max_value = bitdepth_max;
     }
 
     const int sx = in->p.layout != DAV1D_PIXEL_LAYOUT_I444;
@@ -396,7 +407,7 @@
                 int combined = avg * data->uv_luma_mult[uv] +                   \
                                *src * data->uv_mult[uv];                        \
                 val = iclip_pixel( (combined >> 6) +                            \
-                                   (data->uv_offset[uv] * (1 << (BITDEPTH - 8))) );   \
+                                   (data->uv_offset[uv] * (1 << bitdepth_min_8)) );   \
             }                                                                   \
                                                                                 \
             int noise = round2(scaling[ val ] * (grain), data->scaling_shift);  \
@@ -414,7 +425,7 @@
                 int grain = sample_lut(grain_lut, offsets, sx, sy, 0, 0, x, y);
                 int old   = sample_lut(grain_lut, offsets, sx, sy, 1, 0, x, y);
                 grain = (old * w[sx][x][0] + grain * w[sx][x][1] + 16) >> 5;
-                grain = iclip(grain, GRAIN_MIN, GRAIN_MAX);
+                grain = iclip(grain, grain_min, grain_max);
                 add_noise_uv(x, y, grain);
             }
         }
@@ -425,7 +436,7 @@
                 int grain = sample_lut(grain_lut, offsets, sx, sy, 0, 0, x, y);
                 int old   = sample_lut(grain_lut, offsets, sx, sy, 0, 1, x, y);
                 grain = (old * w[sy][y][0] + grain * w[sy][y][1] + 16) >> 5;
-                grain = iclip(grain, GRAIN_MIN, GRAIN_MAX);
+                grain = iclip(grain, grain_min, grain_max);
                 add_noise_uv(x, y, grain);
             }
 
@@ -435,17 +446,17 @@
                 int top = sample_lut(grain_lut, offsets, sx, sy, 0, 1, x, y);
                 int old = sample_lut(grain_lut, offsets, sx, sy, 1, 1, x, y);
                 top = (old * w[sx][x][0] + top * w[sx][x][1] + 16) >> 5;
-                top = iclip(top, GRAIN_MIN, GRAIN_MAX);
+                top = iclip(top, grain_min, grain_max);
 
                 // Blend the current pixel with the left block
                 int grain = sample_lut(grain_lut, offsets, sx, sy, 0, 0, x, y);
                 old = sample_lut(grain_lut, offsets, sx, sy, 1, 0, x, y);
                 grain = (old * w[sx][x][0] + grain * w[sx][x][1] + 16) >> 5;
-                grain = iclip(grain, GRAIN_MIN, GRAIN_MAX);
+                grain = iclip(grain, grain_min, grain_max);
 
                 // Mix the row rows together and apply to image
                 grain = (top * w[sy][y][0] + grain * w[sy][y][1] + 16) >> 5;
-                grain = iclip(grain, GRAIN_MIN, GRAIN_MAX);
+                grain = iclip(grain, grain_min, grain_max);
                 add_noise_uv(x, y, grain);
             }
         }
@@ -469,11 +480,11 @@
 
     // Generate scaling LUTs as needed
     if (data->num_y_points)
-        generate_scaling(data->y_points, data->num_y_points, scaling[0]);
+        generate_scaling(in->p.bpc, data->y_points, data->num_y_points, scaling[0]);
     if (data->num_uv_points[0])
-        generate_scaling(data->uv_points[0], data->num_uv_points[0], scaling[1]);
+        generate_scaling(in->p.bpc, data->uv_points[0], data->num_uv_points[0], scaling[1]);
     if (data->num_uv_points[1])
-        generate_scaling(data->uv_points[1], data->num_uv_points[1], scaling[2]);
+        generate_scaling(in->p.bpc, data->uv_points[1], data->num_uv_points[1], scaling[2]);
 
     // Copy over the non-modified planes
     // TODO: eliminate in favor of per-plane refs
--- a/src/internal.h
+++ b/src/internal.h
@@ -176,6 +176,7 @@
     int a_sz /* w*tile_rows */;
     AV1_COMMON *libaom_cm; // FIXME
     uint8_t jnt_weights[7][7];
+    int bitdepth_max;
 
     struct {
         struct thread_data td;
--- a/src/ipred.h
+++ b/src/ipred.h
@@ -43,7 +43,8 @@
  */
 #define decl_angular_ipred_fn(name) \
 void (name)(pixel *dst, ptrdiff_t stride, const pixel *topleft, \
-            int width, int height, int angle, int max_width, int max_height)
+            int width, int height, int angle, int max_width, int max_height \
+            HIGHBD_DECL_SUFFIX)
 typedef decl_angular_ipred_fn(*angular_ipred_fn);
 
 /*
@@ -63,7 +64,8 @@
  */
 #define decl_cfl_pred_fn(name) \
 void (name)(pixel *dst, ptrdiff_t stride, const pixel *topleft, \
-            int width, int height, const int16_t *ac, int alpha)
+            int width, int height, const int16_t *ac, int alpha \
+            HIGHBD_DECL_SUFFIX)
 typedef decl_cfl_pred_fn(*cfl_pred_fn);
 
 /*
@@ -86,10 +88,7 @@
     pal_pred_fn pal_pred;
 } Dav1dIntraPredDSPContext;
 
-void dav1d_intra_pred_dsp_init_8bpc(Dav1dIntraPredDSPContext *c);
-void dav1d_intra_pred_dsp_init_10bpc(Dav1dIntraPredDSPContext *c);
-
-void dav1d_intra_pred_dsp_init_x86_8bpc(Dav1dIntraPredDSPContext *c);
-void dav1d_intra_pred_dsp_init_x86_10bpc(Dav1dIntraPredDSPContext *c);
+bitfn_decls(void dav1d_intra_pred_dsp_init, Dav1dIntraPredDSPContext *c);
+bitfn_decls(void dav1d_intra_pred_dsp_init_x86, Dav1dIntraPredDSPContext *c);
 
 #endif /* __DAV1D_SRC_IPRED_H__ */
--- a/src/ipred_prepare.h
+++ b/src/ipred_prepare.h
@@ -81,7 +81,8 @@
                                       const pixel *dst, ptrdiff_t stride,
                                       const pixel *prefilter_toplevel_sb_edge,
                                       enum IntraPredMode mode, int *angle,
-                                      int tw, int th, pixel *topleft_out);
+                                      int tw, int th, pixel *topleft_out
+                                      HIGHBD_DECL_SUFFIX);
 
 // These flags are OR'd with the angle argument into intra predictors.
 // ANGLE_USE_EDGE_FILTER_FLAG signals that edges should be convolved
--- a/src/ipred_prepare_tmpl.c
+++ b/src/ipred_prepare_tmpl.c
@@ -83,8 +83,9 @@
                                   const pixel *prefilter_toplevel_sb_edge,
                                   enum IntraPredMode mode, int *const angle,
                                   const int tw, const int th,
-                                  pixel *const topleft_out)
+                                  pixel *const topleft_out HIGHBD_DECL_SUFFIX)
 {
+    const int bitdepth = bitdepth_from_max(bitdepth_max);
     assert(y < h && x < w);
 
     switch (mode) {
@@ -144,7 +145,7 @@
             if (px_have < sz)
                 pixel_set(left, left[sz - px_have], sz - px_have);
         } else {
-            pixel_set(left, have_top ? *dst_top : ((1 << BITDEPTH) >> 1) + 1, sz);
+            pixel_set(left, have_top ? *dst_top : ((1 << bitdepth) >> 1) + 1, sz);
         }
 
         if (av1_intra_prediction_edges[mode].needs_bottomleft) {
@@ -174,7 +175,7 @@
             if (px_have < sz)
                 pixel_set(top + px_have, top[px_have - 1], sz - px_have);
         } else {
-            pixel_set(top, have_left ? dst[-1] : ((1 << BITDEPTH) >> 1) - 1, sz);
+            pixel_set(top, have_left ? dst[-1] : ((1 << bitdepth) >> 1) - 1, sz);
         }
 
         if (av1_intra_prediction_edges[mode].needs_topright) {
@@ -198,7 +199,7 @@
         if (have_left) {
             *topleft_out = have_top ? dst_top[-1] : dst[-1];
         } else {
-            *topleft_out = have_top ? *dst_top : (1 << BITDEPTH) >> 1;
+            *topleft_out = have_top ? *dst_top : (1 << bitdepth) >> 1;
         }
         if (mode == Z2_PRED && tw + th >= 6)
             *topleft_out = (topleft_out[-1] * 5 + topleft_out[0] * 6 +
--- a/src/ipred_tmpl.c
+++ b/src/ipred_tmpl.c
@@ -39,10 +39,10 @@
 
 static NOINLINE void
 splat_dc(pixel *dst, const ptrdiff_t stride,
-         const int width, const int height, const unsigned dc)
+         const int width, const int height, const int dc HIGHBD_DECL_SUFFIX)
 {
-    assert(dc <= (1 << BITDEPTH) - 1);
 #if BITDEPTH == 8
+    assert(dc <= 0xff);
     if (width > 4) {
         const uint64_t dcN = dc * 0x0101010101010101ULL;
         for (int y = 0; y < height; y++) {
@@ -59,6 +59,7 @@
         }
     }
 #else
+    assert(dc <= bitdepth_max);
     const uint64_t dcN = dc * 0x0001000100010001ULL;
     for (int y = 0; y < height; y++) {
         for (int x = 0; x < width; x += sizeof(dcN) >> 1)
@@ -70,8 +71,8 @@
 
 static NOINLINE void
 cfl_pred(pixel *dst, const ptrdiff_t stride,
-         const int width, const int height, const unsigned dc,
-         const int16_t *ac, const int alpha)
+         const int width, const int height, const int dc,
+         const int16_t *ac, const int alpha HIGHBD_DECL_SUFFIX)
 {
     for (int y = 0; y < height; y++) {
         for (int x = 0; x < width; x++) {
@@ -93,17 +94,21 @@
 static void ipred_dc_top_c(pixel *dst, const ptrdiff_t stride,
                            const pixel *const topleft,
                            const int width, const int height, const int a,
-                           const int max_width, const int max_height)
+                           const int max_width, const int max_height
+                           HIGHBD_DECL_SUFFIX) /* splat DC computed from the top edge only (dc_gen_top) */
 {
-    splat_dc(dst, stride, width, height, dc_gen_top(topleft, width));
+    splat_dc(dst, stride, width, height, dc_gen_top(topleft, width)
+             HIGHBD_TAIL_SUFFIX); /* ", bitdepth_max" in 16bpc builds */
 }
 
 static void ipred_cfl_top_c(pixel *dst, const ptrdiff_t stride,
                             const pixel *const topleft,
                             const int width, const int height,
-                            const int16_t *ac, const int alpha)
+                            const int16_t *ac, const int alpha
+                            HIGHBD_DECL_SUFFIX) /* CfL with DC taken from the top edge (dc_gen_top) */
 {
-    cfl_pred(dst, stride, width, height, dc_gen_top(topleft, width), ac, alpha);
+    cfl_pred(dst, stride, width, height, dc_gen_top(topleft, width), ac, alpha
+             HIGHBD_TAIL_SUFFIX); /* ", bitdepth_max" in 16bpc builds */
 }
 
 static unsigned dc_gen_left(const pixel *const topleft, const int height) {
@@ -116,18 +121,21 @@
 static void ipred_dc_left_c(pixel *dst, const ptrdiff_t stride,
                             const pixel *const topleft,
                             const int width, const int height, const int a,
-                            const int max_width, const int max_height)
+                            const int max_width, const int max_height
+                            HIGHBD_DECL_SUFFIX) /* splat DC computed from the left edge only (dc_gen_left) */
 {
-    splat_dc(dst, stride, width, height, dc_gen_left(topleft, height));
+    splat_dc(dst, stride, width, height, dc_gen_left(topleft, height)
+             HIGHBD_TAIL_SUFFIX); /* ", bitdepth_max" in 16bpc builds */
 }
 
 static void ipred_cfl_left_c(pixel *dst, const ptrdiff_t stride,
                              const pixel *const topleft,
                              const int width, const int height,
-                             const int16_t *ac, const int alpha)
+                             const int16_t *ac, const int alpha
+                             HIGHBD_DECL_SUFFIX) /* CfL with DC taken from the left edge (dc_gen_left) */
 {
     unsigned dc = dc_gen_left(topleft, height);
-    cfl_pred(dst, stride, width, height, dc, ac, alpha);
+    cfl_pred(dst, stride, width, height, dc, ac, alpha HIGHBD_TAIL_SUFFIX); /* NOTE(review): unsigned dc -> cfl_pred's int dc */
 }
 
 #if BITDEPTH == 8
@@ -161,18 +169,21 @@
 static void ipred_dc_c(pixel *dst, const ptrdiff_t stride,
                        const pixel *const topleft,
                        const int width, const int height, const int a,
-                       const int max_width, const int max_height)
+                       const int max_width, const int max_height
+                       HIGHBD_DECL_SUFFIX) /* splat DC computed from both edges (dc_gen) */
 {
-    splat_dc(dst, stride, width, height, dc_gen(topleft, width, height));
+    splat_dc(dst, stride, width, height, dc_gen(topleft, width, height)
+             HIGHBD_TAIL_SUFFIX); /* ", bitdepth_max" in 16bpc builds */
 }
 
 static void ipred_cfl_c(pixel *dst, const ptrdiff_t stride,
                         const pixel *const topleft,
                         const int width, const int height,
-                        const int16_t *ac, const int alpha)
+                        const int16_t *ac, const int alpha
+                        HIGHBD_DECL_SUFFIX) /* CfL with DC taken from both edges (dc_gen) */
 {
     unsigned dc = dc_gen(topleft, width, height);
-    cfl_pred(dst, stride, width, height, dc, ac, alpha);
+    cfl_pred(dst, stride, width, height, dc, ac, alpha HIGHBD_TAIL_SUFFIX); /* NOTE(review): unsigned dc -> cfl_pred's int dc */
 }
 
 #undef MULTIPLIER_1x2
@@ -182,23 +193,36 @@
 static void ipred_dc_128_c(pixel *dst, const ptrdiff_t stride,
                            const pixel *const topleft,
                            const int width, const int height, const int a,
-                           const int max_width, const int max_height)
+                           const int max_width, const int max_height
+                           HIGHBD_DECL_SUFFIX) /* splat the mid-range pixel value (no edges used) */
 {
-    splat_dc(dst, stride, width, height, 1 << (BITDEPTH - 1));
+#if BITDEPTH == 16
+    const int dc = (bitdepth_max + 1) >> 1; /* e.g. 512 when bitdepth_max == 1023 */
+#else
+    const int dc = 128; /* half of the 8-bit range */
+#endif
+    splat_dc(dst, stride, width, height, dc HIGHBD_TAIL_SUFFIX); /* NOTE(review): dc derivation duplicated in ipred_cfl_128_c */
 }
 
 static void ipred_cfl_128_c(pixel *dst, const ptrdiff_t stride,
                             const pixel *const topleft,
                             const int width, const int height,
-                            const int16_t *ac, const int alpha)
+                            const int16_t *ac, const int alpha
+                            HIGHBD_DECL_SUFFIX) /* CfL around the mid-range pixel value */
 {
-    cfl_pred(dst, stride, width, height, 1 << (BITDEPTH - 1), ac, alpha);
+#if BITDEPTH == 16
+    const int dc = (bitdepth_max + 1) >> 1; /* e.g. 512 when bitdepth_max == 1023 */
+#else
+    const int dc = 128; /* half of the 8-bit range */
+#endif
+    cfl_pred(dst, stride, width, height, dc, ac, alpha HIGHBD_TAIL_SUFFIX);
 }
 
 static void ipred_v_c(pixel *dst, const ptrdiff_t stride,
                       const pixel *const topleft,
                       const int width, const int height, const int a,
-                      const int max_width, const int max_height)
+                      const int max_width, const int max_height
+                      HIGHBD_DECL_SUFFIX)
 {
     for (int y = 0; y < height; y++) {
         pixel_copy(dst, topleft + 1, width);
@@ -209,7 +233,8 @@
 static void ipred_h_c(pixel *dst, const ptrdiff_t stride,
                       const pixel *const topleft,
                       const int width, const int height, const int a,
-                      const int max_width, const int max_height)
+                      const int max_width, const int max_height
+                      HIGHBD_DECL_SUFFIX)
 {
     for (int y = 0; y < height; y++) {
         pixel_set(dst, topleft[-(1 + y)], width);
@@ -220,7 +245,8 @@
 static void ipred_paeth_c(pixel *dst, const ptrdiff_t stride,
                           const pixel *const tl_ptr,
                           const int width, const int height, const int a,
-                          const int max_width, const int max_height)
+                          const int max_width, const int max_height
+                          HIGHBD_DECL_SUFFIX)
 {
     const int topleft = tl_ptr[0];
     for (int y = 0; y < height; y++) {
@@ -242,7 +268,8 @@
 static void ipred_smooth_c(pixel *dst, const ptrdiff_t stride,
                            const pixel *const topleft,
                            const int width, const int height, const int a,
-                           const int max_width, const int max_height)
+                           const int max_width, const int max_height
+                           HIGHBD_DECL_SUFFIX)
 {
     const uint8_t *const weights_hor = &dav1d_sm_weights[width];
     const uint8_t *const weights_ver = &dav1d_sm_weights[height];
@@ -263,7 +290,8 @@
 static void ipred_smooth_v_c(pixel *dst, const ptrdiff_t stride,
                              const pixel *const topleft,
                              const int width, const int height, const int a,
-                             const int max_width, const int max_height)
+                             const int max_width, const int max_height
+                             HIGHBD_DECL_SUFFIX)
 {
     const uint8_t *const weights_ver = &dav1d_sm_weights[height];
     const int bottom = topleft[-height];
@@ -281,7 +309,8 @@
 static void ipred_smooth_h_c(pixel *dst, const ptrdiff_t stride,
                              const pixel *const topleft,
                              const int width, const int height, const int a,
-                             const int max_width, const int max_height)
+                             const int max_width, const int max_height
+                             HIGHBD_DECL_SUFFIX)
 {
     const uint8_t *const weights_hor = &dav1d_sm_weights[width];
     const int right = topleft[width];
@@ -367,7 +396,8 @@
 }
 
 static void upsample_edge(pixel *const out, const int hsz,
-                          const pixel *const in, const int from, const int to)
+                          const pixel *const in, const int from, const int to
+                          HIGHBD_DECL_SUFFIX)
 {
     static const int8_t kernel[4] = { -1, 9, 9, -1 };
     int i;
@@ -385,7 +415,8 @@
 static void ipred_z1_c(pixel *dst, const ptrdiff_t stride,
                        const pixel *const topleft_in,
                        const int width, const int height, int angle,
-                       const int max_width, const int max_height)
+                       const int max_width, const int max_height
+                       HIGHBD_DECL_SUFFIX)
 {
     const int is_sm = (angle >> 9) & 0x1;
     const int enable_intra_edge_filter = angle >> 10;
@@ -398,8 +429,8 @@
     const int upsample_above = enable_intra_edge_filter ?
         get_upsample(width + height, 90 - angle, is_sm) : 0;
     if (upsample_above) {
-        upsample_edge(top_out, width + height,
-                      &topleft_in[1], -1, width + imin(width, height));
+        upsample_edge(top_out, width + height, &topleft_in[1], -1,
+                      width + imin(width, height) HIGHBD_TAIL_SUFFIX);
         top = top_out;
         max_base_x = 2 * (width + height) - 2;
         dx <<= 1;
@@ -438,7 +469,8 @@
 static void ipred_z2_c(pixel *dst, const ptrdiff_t stride,
                        const pixel *const topleft_in,
                        const int width, const int height, int angle,
-                       const int max_width, const int max_height)
+                       const int max_width, const int max_height
+                       HIGHBD_DECL_SUFFIX)
 {
     const int is_sm = (angle >> 9) & 0x1;
     const int enable_intra_edge_filter = angle >> 10;
@@ -454,7 +486,8 @@
     pixel *const topleft = &edge[height * 2];
 
     if (upsample_above) {
-        upsample_edge(topleft, width + 1, topleft_in, 0, width + 1);
+        upsample_edge(topleft, width + 1, topleft_in, 0, width + 1
+                      HIGHBD_TAIL_SUFFIX);
         dx <<= 1;
     } else {
         const int filter_strength = enable_intra_edge_filter ?
@@ -469,7 +502,8 @@
         }
     }
     if (upsample_left) {
-        upsample_edge(edge, height + 1, &topleft_in[-height], 0, height + 1);
+        upsample_edge(edge, height + 1, &topleft_in[-height], 0, height + 1
+                      HIGHBD_TAIL_SUFFIX);
         dy <<= 1;
     } else {
         const int filter_strength = enable_intra_edge_filter ?
@@ -516,7 +550,8 @@
 static void ipred_z3_c(pixel *dst, const ptrdiff_t stride,
                        const pixel *const topleft_in,
                        const int width, const int height, int angle,
-                       const int max_width, const int max_height)
+                       const int max_width, const int max_height
+                       HIGHBD_DECL_SUFFIX)
 {
     const int is_sm = (angle >> 9) & 0x1;
     const int enable_intra_edge_filter = angle >> 10;
@@ -531,7 +566,8 @@
     if (upsample_left) {
         upsample_edge(left_out, width + height,
                       &topleft_in[-(width + height)],
-                      imax(width - height, 0), width + height + 1);
+                      imax(width - height, 0), width + height + 1
+                      HIGHBD_TAIL_SUFFIX);
         left = &left_out[2 * (width + height) - 2];
         max_base_y = 2 * (width + height) - 2;
         dy <<= 1;
@@ -574,7 +610,8 @@
 static void ipred_filter_c(pixel *dst, const ptrdiff_t stride,
                            const pixel *const topleft_in,
                            const int width, const int height, int filt_idx,
-                           const int max_width, const int max_height)
+                           const int max_width, const int max_height
+                           HIGHBD_DECL_SUFFIX)
 {
     filt_idx &= 511;
     assert(filt_idx < 5);
--- a/src/itx.h
+++ b/src/itx.h
@@ -35,7 +35,8 @@
 #include "src/levels.h"
 
 #define decl_itx_fn(name) \
-void (name)(pixel *dst, ptrdiff_t dst_stride, coef *coeff, int eob)
+void (name)(pixel *dst, ptrdiff_t dst_stride, coef *coeff, int eob \
+            HIGHBD_DECL_SUFFIX) /* ", const int bitdepth_max" in 16bpc builds, empty in 8bpc */
 typedef decl_itx_fn(*itxfm_fn);
 
 typedef struct Dav1dInvTxfmDSPContext {
@@ -42,10 +43,7 @@
     itxfm_fn itxfm_add[N_RECT_TX_SIZES][N_TX_TYPES_PLUS_LL];
 } Dav1dInvTxfmDSPContext;
 
-void dav1d_itx_dsp_init_8bpc(Dav1dInvTxfmDSPContext *c);
-void dav1d_itx_dsp_init_10bpc(Dav1dInvTxfmDSPContext *c);
-
-void dav1d_itx_dsp_init_x86_8bpc(Dav1dInvTxfmDSPContext *c);
-void dav1d_itx_dsp_init_x86_10bpc(Dav1dInvTxfmDSPContext *c);
+bitfn_decls(void dav1d_itx_dsp_init, Dav1dInvTxfmDSPContext *c); /* declares the _8bpc and _16bpc variants */
+bitfn_decls(void dav1d_itx_dsp_init_x86, Dav1dInvTxfmDSPContext *c);
 
 #endif /* __DAV1D_SRC_ITX_H__ */
--- a/src/itx_tmpl.c
+++ b/src/itx_tmpl.c
@@ -46,7 +46,8 @@
 inv_txfm_add_c(pixel *dst, const ptrdiff_t stride,
                coef *const coeff, const int eob,
                const int w, const int h, const int shift1, const int shift2,
-               const itx_1d_fn first_1d_fn, const itx_1d_fn second_1d_fn)
+               const itx_1d_fn first_1d_fn, const itx_1d_fn second_1d_fn
+               HIGHBD_DECL_SUFFIX)
 {
     int i, j;
     const ptrdiff_t sh = imin(h, 32), sw = imin(w, 32);
@@ -54,8 +55,9 @@
     // Maximum value for h and w is 64
     coef tmp[4096 /* w * h */], out[64 /* h */], in_mem[64 /* w */];
     const int is_rect2 = w * 2 == h || h * 2 == w;
-    const int row_clip_max = (1 << (BITDEPTH + 8 - 1)) - 1;
-    const int col_clip_max = (1 << (imax(BITDEPTH + 6, 16) - 1)) -1;
+    const int bitdepth = bitdepth_from_max(bitdepth_max);
+    const int row_clip_max = (1 << (bitdepth + 8 - 1)) - 1;
+    const int col_clip_max = (1 << (imax(bitdepth + 6, 16) - 1)) -1;
     const int col_clip_min = -col_clip_max - 1;
 
     if (w != sw) memset(&in_mem[sw], 0, (w - sw) * sizeof(*in_mem));
@@ -93,10 +95,12 @@
 inv_txfm_add_##type1##_##type2##_##w##x##h##_c(pixel *dst, \
                                                const ptrdiff_t stride, \
                                                coef *const coeff, \
-                                               const int eob) \
+                                               const int eob \
+                                               HIGHBD_DECL_SUFFIX) \
 { \
     inv_txfm_add_c(dst, stride, coeff, eob, w, h, shift1, shift2, \
-                   inv_##type1##w##_1d, inv_##type2##h##_1d); \
+                   inv_##type1##w##_1d, inv_##type2##h##_1d \
+                   HIGHBD_TAIL_SUFFIX); \
 }
 
 #define inv_txfm_fn64(w, h, shift1, shift2) \
@@ -147,9 +151,11 @@
 inv_txfm_fn64(64, 64, 2, 4)
 
 static void inv_txfm_add_wht_wht_4x4_c(pixel *dst, const ptrdiff_t stride,
-                                       coef *const coeff, const int eob)
+                                       coef *const coeff, const int eob
+                                       HIGHBD_DECL_SUFFIX)
 {
-    const int col_clip_max = (1 << (imax(BITDEPTH + 6, 16) - 1)) -1;
+    const int bitdepth = bitdepth_from_max(bitdepth_max);
+    const int col_clip_max = (1 << (imax(bitdepth + 6, 16) - 1)) -1;
     const int col_clip_min = -col_clip_max - 1;
     coef tmp[4 * 4], out[4];
 
--- a/src/lf_apply_tmpl.c
+++ b/src/lf_apply_tmpl.c
@@ -66,7 +66,7 @@
         hmask[3] = 0;
         dsp->lf.loop_filter_sb[0][0](&dst[x * 4], ls, hmask,
                                      (const uint8_t(*)[4]) &lvl[x][0], b4_stride,
-                                     &f->lf.lim_lut, endy4 - starty4);
+                                     &f->lf.lim_lut, endy4 - starty4 HIGHBD_CALL_SUFFIX);
     }
 }
 
@@ -96,7 +96,7 @@
         };
         dsp->lf.loop_filter_sb[0][1](dst, ls, vmask,
                                      (const uint8_t(*)[4]) &lvl[0][1], b4_stride,
-                                     &f->lf.lim_lut, w);
+                                     &f->lf.lim_lut, w HIGHBD_CALL_SUFFIX);
     }
 }
 
@@ -130,10 +130,10 @@
         hmask[2] = 0;
         dsp->lf.loop_filter_sb[1][0](&u[x * 4], ls, hmask,
                                      (const uint8_t(*)[4]) &lvl[x][2], b4_stride,
-                                     &f->lf.lim_lut, endy4 - starty4);
+                                     &f->lf.lim_lut, endy4 - starty4 HIGHBD_CALL_SUFFIX);
         dsp->lf.loop_filter_sb[1][0](&v[x * 4], ls, hmask,
                                      (const uint8_t(*)[4]) &lvl[x][3], b4_stride,
-                                     &f->lf.lim_lut, endy4 - starty4);
+                                     &f->lf.lim_lut, endy4 - starty4 HIGHBD_CALL_SUFFIX);
     }
 }
 
@@ -164,10 +164,10 @@
         };
         dsp->lf.loop_filter_sb[1][1](&u[off_l], ls, vmask,
                                      (const uint8_t(*)[4]) &lvl[0][2], b4_stride,
-                                     &f->lf.lim_lut, w);
+                                     &f->lf.lim_lut, w HIGHBD_CALL_SUFFIX);
         dsp->lf.loop_filter_sb[1][1](&v[off_l], ls, vmask,
                                      (const uint8_t(*)[4]) &lvl[0][3], b4_stride,
-                                     &f->lf.lim_lut, w);
+                                     &f->lf.lim_lut, w HIGHBD_CALL_SUFFIX);
     }
 }
 
--- a/src/lib.c
+++ b/src/lib.c
@@ -264,9 +264,10 @@
         dav1d_apply_grain_8bpc(out, in);
         break;
 #endif
-#if CONFIG_10BPC
+#if CONFIG_16BPC
     case 10:
-        dav1d_apply_grain_10bpc(out, in);
+    case 12:
+        dav1d_apply_grain_16bpc(out, in);
         break;
 #endif
     default:
--- a/src/loopfilter.h
+++ b/src/loopfilter.h
@@ -39,7 +39,7 @@
 #define decl_loopfilter_sb_fn(name) \
 void (name)(pixel *dst, ptrdiff_t stride, const uint32_t *mask, \
             const uint8_t (*lvl)[4], ptrdiff_t lvl_stride, \
-            const Av1FilterLUT *lut, int w)
+            const Av1FilterLUT *lut, int w HIGHBD_DECL_SUFFIX) /* ", const int bitdepth_max" in 16bpc builds */
 typedef decl_loopfilter_sb_fn(*loopfilter_sb_fn);
 
 typedef struct Dav1dLoopFilterDSPContext {
@@ -52,10 +52,7 @@
     loopfilter_sb_fn loop_filter_sb[2][2];
 } Dav1dLoopFilterDSPContext;
 
-void dav1d_loop_filter_dsp_init_8bpc(Dav1dLoopFilterDSPContext *c);
-void dav1d_loop_filter_dsp_init_10bpc(Dav1dLoopFilterDSPContext *c);
-
-void dav1d_loop_filter_dsp_init_x86_8bpc(Dav1dLoopFilterDSPContext *c);
-void dav1d_loop_filter_dsp_init_x86_10bpc(Dav1dLoopFilterDSPContext *c);
+bitfn_decls(void dav1d_loop_filter_dsp_init, Dav1dLoopFilterDSPContext *c); /* declares the _8bpc and _16bpc variants */
+bitfn_decls(void dav1d_loop_filter_dsp_init_x86, Dav1dLoopFilterDSPContext *c);
 
 #endif /* __DAV1D_SRC_LOOPFILTER_H__ */
--- a/src/loopfilter_tmpl.c
+++ b/src/loopfilter_tmpl.c
@@ -36,12 +36,14 @@
 
 static NOINLINE void
 loop_filter(pixel *dst, int E, int I, int H,
-            const ptrdiff_t stridea, const ptrdiff_t strideb, const int wd)
+            const ptrdiff_t stridea, const ptrdiff_t strideb, const int wd
+            HIGHBD_DECL_SUFFIX)
 {
-    const int F = 1 << (BITDEPTH - 8);
-    E <<= BITDEPTH - 8;
-    I <<= BITDEPTH - 8;
-    H <<= BITDEPTH - 8;
+    const int bitdepth_min_8 = bitdepth_from_max(bitdepth_max) - 8;
+    const int F = 1 << bitdepth_min_8;
+    E <<= bitdepth_min_8;
+    I <<= bitdepth_min_8;
+    H <<= bitdepth_min_8;
 
     for (int i = 0; i < 4; i++, dst += stridea) {
         int p6, p5, p4, p3, p2;
@@ -128,15 +130,15 @@
         } else {
             const int hev = abs(p1 - p0) > H || abs(q1 - q0) > H;
 
-#define iclip_diff(v) iclip(v, -128 * (1 << (BITDEPTH - 8)), \
-                                128 * (1 << (BITDEPTH - 8)) - 1)
+#define iclip_diff(v) iclip(v, -128 * (1 << bitdepth_min_8), \
+                                128 * (1 << bitdepth_min_8) - 1)
 
             if (hev) {
                 int f = iclip_diff(p1 - q1), f1, f2;
                 f = iclip_diff(3 * (q0 - p0) + f);
 
-                f1 = imin(f + 4, (128 << (BITDEPTH - 8)) - 1) >> 3;
-                f2 = imin(f + 3, (128 << (BITDEPTH - 8)) - 1) >> 3;
+                f1 = imin(f + 4, (128 << bitdepth_min_8) - 1) >> 3;
+                f2 = imin(f + 3, (128 << bitdepth_min_8) - 1) >> 3;
 
                 dst[strideb * -1] = iclip_pixel(p0 + f2);
                 dst[strideb * +0] = iclip_pixel(q0 - f1);
@@ -143,8 +145,8 @@
             } else {
                 int f = iclip_diff(3 * (q0 - p0)), f1, f2;
 
-                f1 = imin(f + 4, (128 << (BITDEPTH - 8)) - 1) >> 3;
-                f2 = imin(f + 3, (128 << (BITDEPTH - 8)) - 1) >> 3;
+                f1 = imin(f + 4, (128 << bitdepth_min_8) - 1) >> 3;
+                f2 = imin(f + 3, (128 << bitdepth_min_8) - 1) >> 3;
 
                 dst[strideb * -1] = iclip_pixel(p0 + f2);
                 dst[strideb * +0] = iclip_pixel(q0 - f1);
@@ -161,7 +163,8 @@
 static void loop_filter_h_sb128y_c(pixel *dst, const ptrdiff_t stride,
                                    const uint32_t *const vmask,
                                    const uint8_t (*l)[4], ptrdiff_t b4_stride,
-                                   const Av1FilterLUT *lut, const int h)
+                                   const Av1FilterLUT *lut, const int h
+                                   HIGHBD_DECL_SUFFIX)
 {
     const unsigned vm = vmask[0] | vmask[1] | vmask[2];
     for (unsigned y = 1; vm & ~(y - 1);
@@ -173,7 +176,8 @@
             const int H = L >> 4;
             const int E = lut->e[L], I = lut->i[L];
             const int idx = (vmask[2] & y) ? 2 : !!(vmask[1] & y);
-            loop_filter(dst, E, I, H, PXSTRIDE(stride), 1, 4 << idx);
+            loop_filter(dst, E, I, H, PXSTRIDE(stride), 1, 4 << idx
+                        HIGHBD_TAIL_SUFFIX);
         }
     }
 }
@@ -181,7 +185,8 @@
 static void loop_filter_v_sb128y_c(pixel *dst, const ptrdiff_t stride,
                                    const uint32_t *const vmask,
                                    const uint8_t (*l)[4], ptrdiff_t b4_stride,
-                                   const Av1FilterLUT *lut, const int w)
+                                   const Av1FilterLUT *lut, const int w
+                                   HIGHBD_DECL_SUFFIX)
 {
     const unsigned vm = vmask[0] | vmask[1] | vmask[2];
     for (unsigned x = 1; vm & ~(x - 1); x <<= 1, dst += 4, l++) {
@@ -191,7 +196,8 @@
             const int H = L >> 4;
             const int E = lut->e[L], I = lut->i[L];
             const int idx = (vmask[2] & x) ? 2 : !!(vmask[1] & x);
-            loop_filter(dst, E, I, H, 1, PXSTRIDE(stride), 4 << idx);
+            loop_filter(dst, E, I, H, 1, PXSTRIDE(stride), 4 << idx
+                        HIGHBD_TAIL_SUFFIX);
         }
     }
 }
@@ -199,7 +205,8 @@
 static void loop_filter_h_sb128uv_c(pixel *dst, const ptrdiff_t stride,
                                     const uint32_t *const vmask,
                                     const uint8_t (*l)[4], ptrdiff_t b4_stride,
-                                    const Av1FilterLUT *lut, const int h)
+                                    const Av1FilterLUT *lut, const int h
+                                    HIGHBD_DECL_SUFFIX)
 {
     const unsigned vm = vmask[0] | vmask[1];
     for (unsigned y = 1; vm & ~(y - 1);
@@ -211,7 +218,8 @@
             const int H = L >> 4;
             const int E = lut->e[L], I = lut->i[L];
             const int idx = !!(vmask[1] & y);
-            loop_filter(dst, E, I, H, PXSTRIDE(stride), 1, 4 + 2 * idx);
+            loop_filter(dst, E, I, H, PXSTRIDE(stride), 1, 4 + 2 * idx
+                        HIGHBD_TAIL_SUFFIX);
         }
     }
 }
@@ -219,7 +227,8 @@
 static void loop_filter_v_sb128uv_c(pixel *dst, const ptrdiff_t stride,
                                     const uint32_t *const vmask,
                                     const uint8_t (*l)[4], ptrdiff_t b4_stride,
-                                    const Av1FilterLUT *lut, const int w)
+                                    const Av1FilterLUT *lut, const int w
+                                    HIGHBD_DECL_SUFFIX)
 {
     const unsigned vm = vmask[0] | vmask[1];
     for (unsigned x = 1; vm & ~(x - 1); x <<= 1, dst += 4, l++) {
@@ -229,7 +238,8 @@
             const int H = L >> 4;
             const int E = lut->e[L], I = lut->i[L];
             const int idx = !!(vmask[1] & x);
-            loop_filter(dst, E, I, H, 1, PXSTRIDE(stride), 4 + 2 * idx);
+            loop_filter(dst, E, I, H, 1, PXSTRIDE(stride), 4 + 2 * idx
+                        HIGHBD_TAIL_SUFFIX);
         }
     }
 }
--- a/src/looprestoration.h
+++ b/src/looprestoration.h
@@ -55,7 +55,8 @@
             const_left_pixel_row left, \
             const pixel *lpf, ptrdiff_t lpf_stride, \
             int w, int h, const int16_t filterh[7], \
-            const int16_t filterv[7], enum LrEdgeFlags edges)
+            const int16_t filterv[7], enum LrEdgeFlags edges \
+            HIGHBD_DECL_SUFFIX)
 typedef decl_wiener_filter_fn(*wienerfilter_fn);
 
 #define decl_selfguided_filter_fn(name) \
@@ -63,7 +64,7 @@
             const_left_pixel_row left, \
             const pixel *lpf, ptrdiff_t lpf_stride, \
             int w, int h, int sgr_idx, const int16_t sgr_w[2], \
-            const enum LrEdgeFlags edges)
+            const enum LrEdgeFlags edges HIGHBD_DECL_SUFFIX)
 typedef decl_selfguided_filter_fn(*selfguided_fn);
 
 typedef struct Dav1dLoopRestorationDSPContext {
@@ -71,12 +72,8 @@
     selfguided_fn selfguided;
 } Dav1dLoopRestorationDSPContext;
 
-void dav1d_loop_restoration_dsp_init_8bpc(Dav1dLoopRestorationDSPContext *c);
-void dav1d_loop_restoration_dsp_init_10bpc(Dav1dLoopRestorationDSPContext *c);
-
-void dav1d_loop_restoration_dsp_init_arm_8bpc(Dav1dLoopRestorationDSPContext *c);
-void dav1d_loop_restoration_dsp_init_arm_10bpc(Dav1dLoopRestorationDSPContext *c);
-void dav1d_loop_restoration_dsp_init_x86_8bpc(Dav1dLoopRestorationDSPContext *c);
-void dav1d_loop_restoration_dsp_init_x86_10bpc(Dav1dLoopRestorationDSPContext *c);
+bitfn_decls(void dav1d_loop_restoration_dsp_init, Dav1dLoopRestorationDSPContext *c); /* declares the _8bpc and _16bpc variants */
+bitfn_decls(void dav1d_loop_restoration_dsp_init_arm, Dav1dLoopRestorationDSPContext *c);
+bitfn_decls(void dav1d_loop_restoration_dsp_init_x86, Dav1dLoopRestorationDSPContext *c);
 
 #endif /* __DAV1D_SRC_LOOPRESTORATION_H__ */
--- a/src/looprestoration_tmpl.c
+++ b/src/looprestoration_tmpl.c
@@ -136,7 +136,7 @@
                      const pixel *lpf, const ptrdiff_t lpf_stride,
                      const int w, const int h,
                      const int16_t filterh[7], const int16_t filterv[7],
-                     const enum LrEdgeFlags edges)
+                     const enum LrEdgeFlags edges HIGHBD_DECL_SUFFIX)
 {
     // Wiener filtering is applied to a maximum stripe height of 64 + 3 pixels
     // of padding above and below
@@ -150,12 +150,13 @@
     uint16_t hor[70 /*(64 + 3 + 3)*/ * REST_UNIT_STRIDE];
     uint16_t *hor_ptr = hor;
 
-    const int round_bits_h = 3 + (BITDEPTH == 12) * 2;
+    const int bitdepth = bitdepth_from_max(bitdepth_max);
+    const int round_bits_h = 3 + (bitdepth == 12) * 2;
     const int rounding_off_h = 1 << (round_bits_h - 1);
-    const int clip_limit = 1 << ((BITDEPTH) + 1 + 7 - round_bits_h);
+    const int clip_limit = 1 << (bitdepth + 1 + 7 - round_bits_h);
     for (int j = 0; j < h + 6; j++) {
         for (int i = 0; i < w; i++) {
-            int sum = (tmp_ptr[i + 3] << 7) + (1 << (BITDEPTH + 6));
+            int sum = (tmp_ptr[i + 3] << 7) + (1 << (bitdepth + 6));
 
             for (int k = 0; k < 7; k++) {
                 sum += tmp_ptr[i + k] * filterh[k];
@@ -168,9 +169,9 @@
         hor_ptr += REST_UNIT_STRIDE;
     }
 
-    const int round_bits_v = 11 - (BITDEPTH == 12) * 2;
+    const int round_bits_v = 11 - (bitdepth == 12) * 2;
     const int rounding_off_v = 1 << (round_bits_v - 1);
-    const int round_offset = 1 << (BITDEPTH + (round_bits_v - 1));
+    const int round_offset = 1 << (bitdepth + (round_bits_v - 1));
     for (int i = 0; i < w; i++) {
         for (int j = 0; j < h; j++) {
             int sum = (hor[(j + 3) * REST_UNIT_STRIDE + i] << 7) - round_offset;
@@ -408,9 +409,10 @@
     }
 }
 
-static void selfguided_filter(int16_t *dst, const pixel *src,
+static void selfguided_filter(coef *dst, const pixel *src,
                               const ptrdiff_t src_stride, const int w,
-                              const int h, const int n, const int s)
+                              const int h, const int n, const int s
+                              HIGHBD_DECL_SUFFIX)
 {
     const int sgr_one_by_x = n == 25 ? 164 : 455;
 
@@ -431,6 +433,7 @@
         boxsum3(B_, src, w + 6, h + 6);
         boxsum3sqr(A_, src, w + 6, h + 6);
     }
+    const int bitdepth_min_8 = bitdepth_from_max(bitdepth_max) - 8;
 
     int32_t *AA = A - REST_UNIT_STRIDE;
     coef *BB = B - REST_UNIT_STRIDE;
@@ -437,9 +440,9 @@
     for (int j = -1; j < h + 1; j+= step) {
         for (int i = -1; i < w + 1; i++) {
             const int a =
-                (AA[i] + (1 << (2 * (BITDEPTH - 8)) >> 1)) >> (2 * (BITDEPTH - 8));
+                (AA[i] + ((1 << (2 * bitdepth_min_8)) >> 1)) >> (2 * bitdepth_min_8);
             const int b =
-                (BB[i] + (1 << (BITDEPTH - 8) >> 1)) >> (BITDEPTH - 8);
+                (BB[i] + ((1 << bitdepth_min_8) >> 1)) >> bitdepth_min_8;
 
             const unsigned p = imax(a * n - b * b, 0);
             const unsigned z = (p * s + (1 << 19)) >> 20;
@@ -446,7 +449,7 @@
 
             const int x = dav1d_sgr_x_by_xplus1[imin(z, 255)];
             // This is where we invert A and B, so that B is of size coef.
-            AA[i] = (((1 << 8) - x) * BB[i] * sgr_one_by_x + (1 << 11)) >> 12;
+            AA[i] = (((1U << 8) - x) * BB[i] * sgr_one_by_x + (1 << 11)) >> 12;
             BB[i] = x;
         }
         AA += step * REST_UNIT_STRIDE;
@@ -512,7 +515,8 @@
                          const pixel (*const left)[4],
                          const pixel *lpf, const ptrdiff_t lpf_stride,
                          const int w, const int h, const int sgr_idx,
-                         const int16_t sgr_w[2], const enum LrEdgeFlags edges)
+                         const int16_t sgr_w[2], const enum LrEdgeFlags edges
+                         HIGHBD_DECL_SUFFIX)
 {
     // Selfguided filter is applied to a maximum stripe height of 64 + 3 pixels
     // of padding above and below
@@ -522,12 +526,12 @@
 
     // Selfguided filter outputs to a maximum stripe height of 64 and a
     // maximum restoration width of 384 (256 * 1.5)
-    int16_t dst[64 * 384];
+    coef dst[64 * 384];
 
     // both r1 and r0 can't be zero
     if (!dav1d_sgr_params[sgr_idx][0]) {
         const int s1 = dav1d_sgr_params[sgr_idx][3];
-        selfguided_filter(dst, tmp, REST_UNIT_STRIDE, w, h, 9, s1);
+        selfguided_filter(dst, tmp, REST_UNIT_STRIDE, w, h, 9, s1 HIGHBD_TAIL_SUFFIX);
         const int w1 = (1 << 7) - sgr_w[1];
         for (int j = 0; j < h; j++) {
             for (int i = 0; i < w; i++) {
@@ -539,7 +543,7 @@
         }
     } else if (!dav1d_sgr_params[sgr_idx][1]) {
         const int s0 = dav1d_sgr_params[sgr_idx][2];
-        selfguided_filter(dst, tmp, REST_UNIT_STRIDE, w, h, 25, s0);
+        selfguided_filter(dst, tmp, REST_UNIT_STRIDE, w, h, 25, s0 HIGHBD_TAIL_SUFFIX);
         const int w0 = sgr_w[0];
         for (int j = 0; j < h; j++) {
             for (int i = 0; i < w; i++) {
@@ -550,13 +554,13 @@
             p += PXSTRIDE(p_stride);
         }
     } else {
-        int16_t dst1[64 * 384];
+        coef dst1[64 * 384];
         const int s0 = dav1d_sgr_params[sgr_idx][2];
         const int s1 = dav1d_sgr_params[sgr_idx][3];
         const int w0 = sgr_w[0];
         const int w1 = (1 << 7) - w0 - sgr_w[1];
-        selfguided_filter(dst, tmp, REST_UNIT_STRIDE, w, h, 25, s0);
-        selfguided_filter(dst1, tmp, REST_UNIT_STRIDE, w, h, 9, s1);
+        selfguided_filter(dst, tmp, REST_UNIT_STRIDE, w, h, 25, s0 HIGHBD_TAIL_SUFFIX);
+        selfguided_filter(dst1, tmp, REST_UNIT_STRIDE, w, h, 9, s1 HIGHBD_TAIL_SUFFIX);
         for (int j = 0; j < h; j++) {
             for (int i = 0; i < w; i++) {
                 const int u = (p[i] << 4);
--- a/src/lr_apply_tmpl.c
+++ b/src/lr_apply_tmpl.c
@@ -76,7 +76,7 @@
         while (row + stripe_h <= row_h) {
             f->dsp->mc.resize(dst, dst_stride, src, src_stride,
                               dst_w, src_w, 4, f->resize_step[ss_hor],
-                              f->resize_start[ss_hor]);
+                              f->resize_start[ss_hor] HIGHBD_CALL_SUFFIX);
             row += stripe_h; // unmodified stripe_h for the 1st stripe
             stripe_h = 64 >> ss_ver;
             src += stripe_h * PXSTRIDE(src_stride);
@@ -180,11 +180,11 @@
         }
         if (lr->type == DAV1D_RESTORATION_WIENER) {
             dsp->lr.wiener(p, p_stride, left, lpf, lpf_stride, unit_w, stripe_h,
-                           filterh, filterv, edges);
+                           filterh, filterv, edges HIGHBD_CALL_SUFFIX);
         } else {
             assert(lr->type == DAV1D_RESTORATION_SGRPROJ);
             dsp->lr.selfguided(p, p_stride, left, lpf, lpf_stride, unit_w, stripe_h,
-                               lr->sgr_idx, lr->sgr_weights, edges);
+                               lr->sgr_idx, lr->sgr_weights, edges HIGHBD_CALL_SUFFIX);
         }
 
         left += stripe_h;
--- a/src/mc.h
+++ b/src/mc.h
@@ -38,57 +38,59 @@
 #define decl_mc_fn(name) \
 void (name)(pixel *dst, ptrdiff_t dst_stride, \
             const pixel *src, ptrdiff_t src_stride, \
-            int w, int h, int mx, int my)
+            int w, int h, int mx, int my HIGHBD_DECL_SUFFIX)
 typedef decl_mc_fn(*mc_fn);
 
 #define decl_mc_scaled_fn(name) \
 void (name)(pixel *dst, ptrdiff_t dst_stride, \
             const pixel *src, ptrdiff_t src_stride, \
-            int w, int h, int mx, int my, int dx, int dy)
+            int w, int h, int mx, int my, int dx, int dy HIGHBD_DECL_SUFFIX)
 typedef decl_mc_scaled_fn(*mc_scaled_fn);
 
 #define decl_warp8x8_fn(name) \
 void (name)(pixel *dst, ptrdiff_t dst_stride, \
             const pixel *src, ptrdiff_t src_stride, \
-            const int16_t *abcd, int mx, int my)
+            const int16_t *abcd, int mx, int my HIGHBD_DECL_SUFFIX)
 typedef decl_warp8x8_fn(*warp8x8_fn);
 
 #define decl_mct_fn(name) \
 void (name)(int16_t *tmp, const pixel *src, ptrdiff_t src_stride, \
-            int w, int h, int mx, int my)
+            int w, int h, int mx, int my HIGHBD_DECL_SUFFIX)
 typedef decl_mct_fn(*mct_fn);
 
 #define decl_mct_scaled_fn(name) \
 void (name)(int16_t *tmp, const pixel *src, ptrdiff_t src_stride, \
-            int w, int h, int mx, int my, int dx, int dy)
+            int w, int h, int mx, int my, int dx, int dy HIGHBD_DECL_SUFFIX)
 typedef decl_mct_scaled_fn(*mct_scaled_fn);
 
 #define decl_warp8x8t_fn(name) \
 void (name)(int16_t *tmp, const ptrdiff_t tmp_stride, \
             const pixel *src, ptrdiff_t src_stride, \
-            const int16_t *abcd, int mx, int my)
+            const int16_t *abcd, int mx, int my HIGHBD_DECL_SUFFIX)
 typedef decl_warp8x8t_fn(*warp8x8t_fn);
 
 #define decl_avg_fn(name) \
 void (name)(pixel *dst, ptrdiff_t dst_stride, \
-            const int16_t *tmp1, const int16_t *tmp2, int w, int h)
+            const int16_t *tmp1, const int16_t *tmp2, int w, int h \
+            HIGHBD_DECL_SUFFIX)
 typedef decl_avg_fn(*avg_fn);
 
 #define decl_w_avg_fn(name) \
 void (name)(pixel *dst, ptrdiff_t dst_stride, \
-            const int16_t *tmp1, const int16_t *tmp2, int w, int h, int weight)
+            const int16_t *tmp1, const int16_t *tmp2, int w, int h, int weight \
+            HIGHBD_DECL_SUFFIX)
 typedef decl_w_avg_fn(*w_avg_fn);
 
 #define decl_mask_fn(name) \
 void (name)(pixel *dst, ptrdiff_t dst_stride, \
             const int16_t *tmp1, const int16_t *tmp2, int w, int h, \
-            const uint8_t *mask)
+            const uint8_t *mask HIGHBD_DECL_SUFFIX)
 typedef decl_mask_fn(*mask_fn);
 
 #define decl_w_mask_fn(name) \
 void (name)(pixel *dst, ptrdiff_t dst_stride, \
             const int16_t *tmp1, const int16_t *tmp2, int w, int h, \
-            uint8_t *mask, int sign)
+            uint8_t *mask, int sign HIGHBD_DECL_SUFFIX)
 typedef decl_w_mask_fn(*w_mask_fn);
 
 #define decl_blend_fn(name) \
@@ -108,7 +110,7 @@
 #define decl_resize_fn(name) \
 void (name)(pixel *dst, ptrdiff_t dst_stride, \
             const pixel *src, ptrdiff_t src_stride, \
-            int dst_w, int src_w, int h, int dx, int mx)
+            int dst_w, int src_w, int h, int dx, int mx HIGHBD_DECL_SUFFIX)
 typedef decl_resize_fn(*resize_fn);
 
 typedef struct Dav1dMCDSPContext {
@@ -129,13 +131,8 @@
     resize_fn resize;
 } Dav1dMCDSPContext;
 
-void dav1d_mc_dsp_init_8bpc(Dav1dMCDSPContext *c);
-void dav1d_mc_dsp_init_10bpc(Dav1dMCDSPContext *c);
-
-void dav1d_mc_dsp_init_arm_8bpc(Dav1dMCDSPContext *c);
-void dav1d_mc_dsp_init_arm_10bpc(Dav1dMCDSPContext *c);
-
-void dav1d_mc_dsp_init_x86_8bpc(Dav1dMCDSPContext *c);
-void dav1d_mc_dsp_init_x86_10bpc(Dav1dMCDSPContext *c);
+bitfn_decls(void dav1d_mc_dsp_init, Dav1dMCDSPContext *c);
+bitfn_decls(void dav1d_mc_dsp_init_arm, Dav1dMCDSPContext *c);
+bitfn_decls(void dav1d_mc_dsp_init_x86, Dav1dMCDSPContext *c);
 
 #endif /* __DAV1D_SRC_MC_H__ */
--- a/src/mc_tmpl.c
+++ b/src/mc_tmpl.c
@@ -37,6 +37,13 @@
 #include "src/mc.h"
 #include "src/tables.h"
 
+#if BITDEPTH == 8
+#define get_intermediate_bits(bitdepth_max) 4
+#else
+// 4 for 10 bits/component, 2 for 12 bits/component
+#define get_intermediate_bits(bitdepth_max) (14 - bitdepth_from_max(bitdepth_max))
+#endif
+
 static NOINLINE void
 put_c(pixel *dst, const ptrdiff_t dst_stride,
       const pixel *src, const ptrdiff_t src_stride, const int w, int h)
@@ -51,11 +58,12 @@
 
 static NOINLINE void
 prep_c(int16_t *tmp, const pixel *src, const ptrdiff_t src_stride,
-       const int w, int h)
+       const int w, int h HIGHBD_DECL_SUFFIX)
 {
+    const int intermediate_bits = get_intermediate_bits(bitdepth_max);
     do {
         for (int x = 0; x < w; x++)
-            tmp[x] = src[x] << 4;
+            tmp[x] = src[x] << intermediate_bits;
 
         tmp += w;
         src += src_stride;
@@ -73,7 +81,7 @@
      F[7] * src[x + +4 * stride])
 
 #define DAV1D_FILTER_8TAP_RND(src, x, F, stride, sh) \
-    ((FILTER_8TAP(src, x, F, stride) + ((1 << sh) >> 1)) >> sh)
+    ((FILTER_8TAP(src, x, F, stride) + ((1 << (sh)) >> 1)) >> (sh))
 
 #define DAV1D_FILTER_8TAP_CLIP(src, x, F, stride, sh) \
     iclip_pixel(DAV1D_FILTER_8TAP_RND(src, x, F, stride, sh))
@@ -96,8 +104,11 @@
 put_8tap_c(pixel *dst, ptrdiff_t dst_stride,
            const pixel *src, ptrdiff_t src_stride,
            const int w, int h, const int mx, const int my,
-           const int filter_type)
+           const int filter_type HIGHBD_DECL_SUFFIX)
 {
+    const int intermediate_bits = get_intermediate_bits(bitdepth_max);
+    const int intermediate_rnd = (1 << intermediate_bits) >> 1;
+
     GET_FILTERS();
     dst_stride = PXSTRIDE(dst_stride);
     src_stride = PXSTRIDE(src_stride);
@@ -110,7 +121,8 @@
             src -= src_stride * 3;
             do {
                 for (int x = 0; x < w; x++)
-                    mid_ptr[x] = DAV1D_FILTER_8TAP_RND(src, x, fh, 1, 2);
+                    mid_ptr[x] = DAV1D_FILTER_8TAP_RND(src, x, fh, 1,
+                                                       6 - intermediate_bits);
 
                 mid_ptr += 128;
                 src += src_stride;
@@ -119,7 +131,8 @@
             mid_ptr = mid + 128 * 3;
             do {
                 for (int x = 0; x < w; x++)
-                    dst[x] = DAV1D_FILTER_8TAP_CLIP(mid_ptr, x, fv, 128, 10);
+                    dst[x] = DAV1D_FILTER_8TAP_CLIP(mid_ptr, x, fv, 128,
+                                                    6 + intermediate_bits);
 
                 mid_ptr += 128;
                 dst += dst_stride;
@@ -127,8 +140,9 @@
         } else {
             do {
                 for (int x = 0; x < w; x++) {
-                    const int px = DAV1D_FILTER_8TAP_RND(src, x, fh, 1, 2);
-                    dst[x] = iclip_pixel((px + 8) >> 4);
+                    const int px = DAV1D_FILTER_8TAP_RND(src, x, fh, 1,
+                                                         6 - intermediate_bits);
+                    dst[x] = iclip_pixel((px + intermediate_rnd) >> intermediate_bits);
                 }
 
                 dst += dst_stride;
@@ -151,8 +165,11 @@
 put_8tap_scaled_c(pixel *dst, const ptrdiff_t dst_stride,
                   const pixel *src, ptrdiff_t src_stride,
                   const int w, int h, const int mx, int my,
-                  const int dx, const int dy, const int filter_type)
+                  const int dx, const int dy, const int filter_type
+                  HIGHBD_DECL_SUFFIX)
 {
+    const int intermediate_bits = get_intermediate_bits(bitdepth_max);
+    const int intermediate_rnd = (1 << intermediate_bits) >> 1;
     int tmp_h = (((h - 1) * dy + my) >> 10) + 8;
     int16_t mid[128 * (256 + 7)], *mid_ptr = mid;
     src_stride = PXSTRIDE(src_stride);
@@ -164,7 +181,9 @@
 
         for (x = 0; x < w; x++) {
             GET_H_FILTER(imx >> 6);
-            mid_ptr[x] = fh ? DAV1D_FILTER_8TAP_RND(src, ioff, fh, 1, 2) : src[ioff] << 4;
+            mid_ptr[x] = fh ? DAV1D_FILTER_8TAP_RND(src, ioff, fh, 1,
+                                                    6 - intermediate_bits) :
+                              src[ioff] << intermediate_bits;
             imx += dx;
             ioff += imx >> 10;
             imx &= 0x3ff;
@@ -180,8 +199,10 @@
         GET_V_FILTER(my >> 6);
 
         for (x = 0; x < w; x++)
-            dst[x] = fv ? DAV1D_FILTER_8TAP_CLIP(mid_ptr, x, fv, 128, 10) :
-                          iclip_pixel((mid_ptr[x] + 8) >> 4);
+            dst[x] = fv ? DAV1D_FILTER_8TAP_CLIP(mid_ptr, x, fv, 128,
+                                                 6 + intermediate_bits) :
+                          iclip_pixel((mid_ptr[x] + intermediate_rnd) >>
+                                              intermediate_bits);
 
         my += dy;
         mid_ptr += (my >> 10) * 128;
@@ -193,8 +214,9 @@
 static NOINLINE void
 prep_8tap_c(int16_t *tmp, const pixel *src, ptrdiff_t src_stride,
             const int w, int h, const int mx, const int my,
-            const int filter_type)
+            const int filter_type HIGHBD_DECL_SUFFIX)
 {
+    const int intermediate_bits = get_intermediate_bits(bitdepth_max);
     GET_FILTERS();
     src_stride = PXSTRIDE(src_stride);
 
@@ -206,7 +228,8 @@
             src -= src_stride * 3;
             do {
                 for (int x = 0; x < w; x++)
-                    mid_ptr[x] = DAV1D_FILTER_8TAP_RND(src, x, fh, 1, 2);
+                    mid_ptr[x] = DAV1D_FILTER_8TAP_RND(src, x, fh, 1,
+                                                       6 - intermediate_bits);
 
                 mid_ptr += 128;
                 src += src_stride;
@@ -223,7 +246,8 @@
         } else {
             do {
                 for (int x = 0; x < w; x++)
-                    tmp[x] = DAV1D_FILTER_8TAP_RND(src, x, fh, 1, 2);
+                    tmp[x] = DAV1D_FILTER_8TAP_RND(src, x, fh, 1,
+                                                   6 - intermediate_bits);
 
                 tmp += w;
                 src += src_stride;
@@ -232,20 +256,23 @@
     } else if (fv) {
         do {
             for (int x = 0; x < w; x++)
-                tmp[x] = DAV1D_FILTER_8TAP_RND(src, x, fv, src_stride, 2);
+                tmp[x] = DAV1D_FILTER_8TAP_RND(src, x, fv, src_stride,
+                                               6 - intermediate_bits);
 
             tmp += w;
             src += src_stride;
         } while (--h);
     } else
-        prep_c(tmp, src, src_stride, w, h);
+        prep_c(tmp, src, src_stride, w, h HIGHBD_TAIL_SUFFIX);
 }
 
 static NOINLINE void
 prep_8tap_scaled_c(int16_t *tmp, const pixel *src, ptrdiff_t src_stride,
                    const int w, int h, const int mx, int my,
-                   const int dx, const int dy, const int filter_type)
+                   const int dx, const int dy, const int filter_type
+                   HIGHBD_DECL_SUFFIX)
 {
+    const int intermediate_bits = get_intermediate_bits(bitdepth_max);
     int tmp_h = (((h - 1) * dy + my) >> 10) + 8;
     int16_t mid[128 * (256 + 7)], *mid_ptr = mid;
     src_stride = PXSTRIDE(src_stride);
@@ -257,7 +284,9 @@
 
         for (x = 0; x < w; x++) {
             GET_H_FILTER(imx >> 6);
-            mid_ptr[x] = fh ? DAV1D_FILTER_8TAP_RND(src, ioff, fh, 1, 2) : src[ioff] << 4;
+            mid_ptr[x] = fh ? DAV1D_FILTER_8TAP_RND(src, ioff, fh, 1,
+                                                    6 - intermediate_bits) :
+                              src[ioff] << intermediate_bits;
             imx += dx;
             ioff += imx >> 10;
             imx &= 0x3ff;
@@ -288,10 +317,11 @@
                                 const pixel *const src, \
                                 const ptrdiff_t src_stride, \
                                 const int w, const int h, \
-                                const int mx, const int my) \
+                                const int mx, const int my \
+                                HIGHBD_DECL_SUFFIX) \
 { \
     put_8tap_c(dst, dst_stride, src, src_stride, w, h, mx, my, \
-               type_h | (type_v << 2)); \
+               type_h | (type_v << 2) HIGHBD_TAIL_SUFFIX); \
 } \
 static void put_8tap_##type##_scaled_c(pixel *const dst, \
                                        const ptrdiff_t dst_stride, \
@@ -299,19 +329,21 @@
                                        const ptrdiff_t src_stride, \
                                        const int w, const int h, \
                                        const int mx, const int my, \
-                                       const int dx, const int dy) \
+                                       const int dx, const int dy \
+                                       HIGHBD_DECL_SUFFIX) \
 { \
     put_8tap_scaled_c(dst, dst_stride, src, src_stride, w, h, mx, my, dx, dy, \
-                      type_h | (type_v << 2)); \
+                      type_h | (type_v << 2) HIGHBD_TAIL_SUFFIX); \
 } \
 static void prep_8tap_##type##_c(int16_t *const tmp, \
                                  const pixel *const src, \
                                  const ptrdiff_t src_stride, \
                                  const int w, const int h, \
-                                 const int mx, const int my) \
+                                 const int mx, const int my \
+                                 HIGHBD_DECL_SUFFIX) \
 { \
     prep_8tap_c(tmp, src, src_stride, w, h, mx, my, \
-                type_h | (type_v << 2)); \
+                type_h | (type_v << 2) HIGHBD_TAIL_SUFFIX); \
 } \
 static void prep_8tap_##type##_scaled_c(int16_t *const tmp, \
                                         const pixel *const src, \
@@ -318,10 +350,11 @@
                                         const ptrdiff_t src_stride, \
                                         const int w, const int h, \
                                         const int mx, const int my, \
-                                        const int dx, const int dy) \
+                                        const int dx, const int dy \
+                                        HIGHBD_DECL_SUFFIX) \
 { \
     prep_8tap_scaled_c(tmp, src, src_stride, w, h, mx, my, dx, dy, \
-                       type_h | (type_v << 2)); \
+                       type_h | (type_v << 2) HIGHBD_TAIL_SUFFIX); \
 }
 
 filter_fns(regular,        DAV1D_FILTER_8TAP_REGULAR, DAV1D_FILTER_8TAP_REGULAR)
@@ -338,7 +371,7 @@
     (16 * src[x] + ((mxy) * (src[x + stride] - src[x])))
 
 #define FILTER_BILIN_RND(src, x, mxy, stride, sh) \
-    ((FILTER_BILIN(src, x, mxy, stride) + ((1 << sh) >> 1)) >> sh)
+    ((FILTER_BILIN(src, x, mxy, stride) + ((1 << (sh)) >> 1)) >> (sh))
 
 #define FILTER_BILIN_CLIP(src, x, mxy, stride, sh) \
     iclip_pixel(FILTER_BILIN_RND(src, x, mxy, stride, sh))
@@ -345,8 +378,11 @@
 
 static void put_bilin_c(pixel *dst, ptrdiff_t dst_stride,
                         const pixel *src, ptrdiff_t src_stride,
-                        const int w, int h, const int mx, const int my)
+                        const int w, int h, const int mx, const int my
+                        HIGHBD_DECL_SUFFIX)
 {
+    const int intermediate_bits = get_intermediate_bits(bitdepth_max);
+    const int intermediate_rnd = (1 << intermediate_bits) >> 1;
     dst_stride = PXSTRIDE(dst_stride);
     src_stride = PXSTRIDE(src_stride);
 
@@ -357,7 +393,8 @@
 
             do {
                 for (int x = 0; x < w; x++)
-                    mid_ptr[x] = FILTER_BILIN(src, x, mx, 1);
+                    mid_ptr[x] = FILTER_BILIN_RND(src, x, mx, 1,
+                                                  4 - intermediate_bits);
 
                 mid_ptr += 128;
                 src += src_stride;
@@ -366,7 +403,8 @@
             mid_ptr = mid;
             do {
                 for (int x = 0; x < w; x++)
-                    dst[x] = FILTER_BILIN_CLIP(mid_ptr, x, my, 128, 8);
+                    dst[x] = FILTER_BILIN_CLIP(mid_ptr, x, my, 128,
+                                               4 + intermediate_bits);
 
                 mid_ptr += 128;
                 dst += dst_stride;
@@ -373,8 +411,11 @@
             } while (--h);
         } else {
             do {
-                for (int x = 0; x < w; x++)
-                    dst[x] = FILTER_BILIN_CLIP(src, x, mx, 1, 4);
+                for (int x = 0; x < w; x++) {
+                    const int px = FILTER_BILIN_RND(src, x, mx, 1,
+                                                    4 - intermediate_bits);
+                    dst[x] = iclip_pixel((px + intermediate_rnd) >> intermediate_bits);
+                }
 
                 dst += dst_stride;
                 src += src_stride;
@@ -395,8 +436,10 @@
 static void put_bilin_scaled_c(pixel *dst, ptrdiff_t dst_stride,
                                const pixel *src, ptrdiff_t src_stride,
                                const int w, int h, const int mx, int my,
-                               const int dx, const int dy)
+                               const int dx, const int dy
+                               HIGHBD_DECL_SUFFIX)
 {
+    const int intermediate_bits = get_intermediate_bits(bitdepth_max);
     int tmp_h = (((h - 1) * dy + my) >> 10) + 2;
     int16_t mid[128 * (256 + 1)], *mid_ptr = mid;
 
@@ -405,7 +448,8 @@
         int imx = mx, ioff = 0;
 
         for (x = 0; x < w; x++) {
-            mid_ptr[x] = FILTER_BILIN(src, ioff, imx >> 6, 1);
+            mid_ptr[x] = FILTER_BILIN_RND(src, ioff, imx >> 6, 1,
+                                          4 - intermediate_bits);
             imx += dx;
             ioff += imx >> 10;
             imx &= 0x3ff;
@@ -420,7 +464,8 @@
         int x;
 
         for (x = 0; x < w; x++)
-            dst[x] = FILTER_BILIN_CLIP(mid_ptr, x, my >> 6, 128, 8);
+            dst[x] = FILTER_BILIN_CLIP(mid_ptr, x, my >> 6, 128,
+                                       4 + intermediate_bits);
 
         my += dy;
         mid_ptr += (my >> 10) * 128;
@@ -431,8 +476,10 @@
 
 static void prep_bilin_c(int16_t *tmp,
                          const pixel *src, ptrdiff_t src_stride,
-                         const int w, int h, const int mx, const int my)
+                         const int w, int h, const int mx, const int my
+                         HIGHBD_DECL_SUFFIX)
 {
+    const int intermediate_bits = get_intermediate_bits(bitdepth_max);
     src_stride = PXSTRIDE(src_stride);
 
     if (mx) {
@@ -442,7 +489,8 @@
 
             do {
                 for (int x = 0; x < w; x++)
-                    mid_ptr[x] = FILTER_BILIN(src, x, mx, 1);
+                    mid_ptr[x] = FILTER_BILIN_RND(src, x, mx, 1,
+                                                  4 - intermediate_bits);
 
                 mid_ptr += 128;
                 src += src_stride;
@@ -459,7 +507,8 @@
         } else {
             do {
                 for (int x = 0; x < w; x++)
-                    tmp[x] = FILTER_BILIN(src, x, mx, 1);
+                    tmp[x] = FILTER_BILIN_RND(src, x, mx, 1,
+                                              4 - intermediate_bits);
 
                 tmp += w;
                 src += src_stride;
@@ -468,20 +517,22 @@
     } else if (my) {
         do {
             for (int x = 0; x < w; x++)
-                tmp[x] = FILTER_BILIN(src, x, my, src_stride);
+                tmp[x] = FILTER_BILIN_RND(src, x, my, src_stride,
+                                          4 - intermediate_bits);
 
             tmp += w;
             src += src_stride;
         } while (--h);
     } else
-        prep_c(tmp, src, src_stride, w, h);
+        prep_c(tmp, src, src_stride, w, h HIGHBD_TAIL_SUFFIX);
 }
 
 static void prep_bilin_scaled_c(int16_t *tmp,
                                 const pixel *src, ptrdiff_t src_stride,
                                 const int w, int h, const int mx, int my,
-                                const int dx, const int dy)
+                                const int dx, const int dy HIGHBD_DECL_SUFFIX)
 {
+    const int intermediate_bits = get_intermediate_bits(bitdepth_max);
     int tmp_h = (((h - 1) * dy + my) >> 10) + 2;
     int16_t mid[128 * (256 + 1)], *mid_ptr = mid;
 
@@ -490,7 +541,8 @@
         int imx = mx, ioff = 0;
 
         for (x = 0; x < w; x++) {
-            mid_ptr[x] = FILTER_BILIN(src, ioff, imx >> 6, 1);
+            mid_ptr[x] = FILTER_BILIN_RND(src, ioff, imx >> 6, 1,
+                                          4 - intermediate_bits);
             imx += dx;
             ioff += imx >> 10;
             imx &= 0x3ff;
@@ -515,11 +567,14 @@
 }
 
 static void avg_c(pixel *dst, const ptrdiff_t dst_stride,
-                  const int16_t *tmp1, const int16_t *tmp2, const int w, int h)
+                  const int16_t *tmp1, const int16_t *tmp2, const int w, int h
+                  HIGHBD_DECL_SUFFIX)
 {
+    const int intermediate_bits = get_intermediate_bits(bitdepth_max);
+    const int sh = intermediate_bits + 1, rnd = 1 << intermediate_bits;
     do {
         for (int x = 0; x < w; x++)
-            dst[x] = iclip_pixel((tmp1[x] + tmp2[x] + 16) >> 5);
+            dst[x] = iclip_pixel((tmp1[x] + tmp2[x] + rnd) >> sh);
 
         tmp1 += w;
         tmp2 += w;
@@ -529,12 +584,14 @@
 
 static void w_avg_c(pixel *dst, const ptrdiff_t dst_stride,
                     const int16_t *tmp1, const int16_t *tmp2, const int w, int h,
-                    const int weight)
+                    const int weight HIGHBD_DECL_SUFFIX)
 {
+    const int intermediate_bits = get_intermediate_bits(bitdepth_max);
+    const int sh = intermediate_bits + 4, rnd = 8 << intermediate_bits;
     do {
         for (int x = 0; x < w; x++)
             dst[x] = iclip_pixel((tmp1[x] * weight +
-                                  tmp2[x] * (16 - weight) + 128) >> 8);
+                                  tmp2[x] * (16 - weight) + rnd) >> sh);
 
         tmp1 += w;
         tmp2 += w;
@@ -544,12 +601,14 @@
 
 static void mask_c(pixel *dst, const ptrdiff_t dst_stride,
                    const int16_t *tmp1, const int16_t *tmp2, const int w, int h,
-                   const uint8_t *mask)
+                   const uint8_t *mask HIGHBD_DECL_SUFFIX)
 {
+    const int intermediate_bits = get_intermediate_bits(bitdepth_max);
+    const int sh = intermediate_bits + 6, rnd = 32 << intermediate_bits;
     do {
         for (int x = 0; x < w; x++)
             dst[x] = iclip_pixel((tmp1[x] * mask[x] +
-                                  tmp2[x] * (64 - mask[x]) + 512) >> 10);
+                                  tmp2[x] * (64 - mask[x]) + rnd) >> sh);
 
         tmp1 += w;
         tmp2 += w;
@@ -603,23 +662,27 @@
 static void w_mask_c(pixel *dst, const ptrdiff_t dst_stride,
                      const int16_t *tmp1, const int16_t *tmp2, const int w, int h,
                      uint8_t *mask, const int sign,
-                     const int ss_hor, const int ss_ver)
+                     const int ss_hor, const int ss_ver HIGHBD_DECL_SUFFIX)
 {
     // store mask at 2x2 resolution, i.e. store 2x1 sum for even rows,
     // and then load this intermediate to calculate final value for odd rows
-    const int rnd = 8 << (BITDEPTH - 8);
+    const int intermediate_bits = get_intermediate_bits(bitdepth_max);
+    const int bitdepth = bitdepth_from_max(bitdepth_max);
+    const int sh = intermediate_bits + 6, rnd = 32 << intermediate_bits;
+    const int mask_sh = bitdepth + intermediate_bits - 4;
+    const int mask_rnd = 1 << (mask_sh - 5);
     do {
         for (int x = 0; x < w; x++) {
-            const int m = imin(38 + ((abs(tmp1[x] - tmp2[x]) + rnd) >> BITDEPTH), 64);
+            const int m = imin(38 + ((abs(tmp1[x] - tmp2[x]) + mask_rnd) >> mask_sh), 64);
             dst[x] = iclip_pixel((tmp1[x] * m +
-                                  tmp2[x] * (64 - m) + 512) >> 10);
+                                  tmp2[x] * (64 - m) + rnd) >> sh);
 
             if (ss_hor) {
                 x++;
 
-                const int n = imin(38 + ((abs(tmp1[x] - tmp2[x]) + rnd) >> BITDEPTH), 64);
+                const int n = imin(38 + ((abs(tmp1[x] - tmp2[x]) + mask_rnd) >> mask_sh), 64);
                 dst[x] = iclip_pixel((tmp1[x] * n +
-                                      tmp2[x] * (64 - n) + 512) >> 10);
+                                      tmp2[x] * (64 - n) + rnd) >> sh);
 
                 if (h & ss_ver) {
                     mask[x >> 1] = (m + n + mask[x >> 1] + 2 - sign) >> 2;
@@ -644,9 +707,10 @@
 static void w_mask_##ssn##_c(pixel *const dst, const ptrdiff_t dst_stride, \
                              const int16_t *const tmp1, const int16_t *const tmp2, \
                              const int w, const int h, uint8_t *mask, \
-                             const int sign) \
+                             const int sign HIGHBD_DECL_SUFFIX) \
 { \
-    w_mask_c(dst, dst_stride, tmp1, tmp2, w, h, mask, sign, ss_hor, ss_ver); \
+    w_mask_c(dst, dst_stride, tmp1, tmp2, w, h, mask, sign, ss_hor, ss_ver \
+             HIGHBD_TAIL_SUFFIX); \
 }
 
 w_mask_fns(444, 0, 0);
@@ -666,7 +730,7 @@
      F[7] * src[x + +4 * stride])
 
 #define FILTER_WARP_RND(src, x, F, stride, sh) \
-    ((FILTER_WARP(src, x, F, stride) + ((1 << sh) >> 1)) >> sh)
+    ((FILTER_WARP(src, x, F, stride) + ((1 << (sh)) >> 1)) >> (sh))
 
 #define FILTER_WARP_CLIP(src, x, F, stride, sh) \
     iclip_pixel(FILTER_WARP_RND(src, x, F, stride, sh))
@@ -673,8 +737,10 @@
 
 static void warp_affine_8x8_c(pixel *dst, const ptrdiff_t dst_stride,
                               const pixel *src, const ptrdiff_t src_stride,
-                              const int16_t *const abcd, int mx, int my)
+                              const int16_t *const abcd, int mx, int my
+                              HIGHBD_DECL_SUFFIX)
 {
+    const int intermediate_bits = get_intermediate_bits(bitdepth_max);
     int16_t mid[15 * 8], *mid_ptr = mid;
 
     src -= 3 * PXSTRIDE(src_stride);
@@ -683,7 +749,8 @@
             const int8_t *const filter =
                 dav1d_mc_warp_filter[64 + ((tmx + 512) >> 10)];
 
-            mid_ptr[x] = FILTER_WARP_RND(src, x, filter, 1, 3);
+            mid_ptr[x] = FILTER_WARP_RND(src, x, filter, 1,
+                                         7 - intermediate_bits);
         }
         src += PXSTRIDE(src_stride);
         mid_ptr += 8;
@@ -695,7 +762,8 @@
             const int8_t *const filter =
                 dav1d_mc_warp_filter[64 + ((tmy + 512) >> 10)];
 
-            dst[x] = FILTER_WARP_CLIP(mid_ptr, x, filter, 8, 11);
+            dst[x] = FILTER_WARP_CLIP(mid_ptr, x, filter, 8,
+                                      7 + intermediate_bits);
         }
         mid_ptr += 8;
         dst += PXSTRIDE(dst_stride);
@@ -704,8 +772,10 @@
 
 static void warp_affine_8x8t_c(int16_t *tmp, const ptrdiff_t tmp_stride,
                                const pixel *src, const ptrdiff_t src_stride,
-                               const int16_t *const abcd, int mx, int my)
+                               const int16_t *const abcd, int mx, int my
+                               HIGHBD_DECL_SUFFIX)
 {
+    const int intermediate_bits = get_intermediate_bits(bitdepth_max);
     int16_t mid[15 * 8], *mid_ptr = mid;
 
     src -= 3 * PXSTRIDE(src_stride);
@@ -714,7 +784,8 @@
             const int8_t *const filter =
                 dav1d_mc_warp_filter[64 + ((tmx + 512) >> 10)];
 
-            mid_ptr[x] = FILTER_WARP_RND(src, x, filter, 1, 3);
+            mid_ptr[x] = FILTER_WARP_RND(src, x, filter, 1,
+                                         7 - intermediate_bits);
         }
         src += PXSTRIDE(src_stride);
         mid_ptr += 8;
@@ -785,7 +856,7 @@
 static void resize_c(pixel *dst, const ptrdiff_t dst_stride,
                      const pixel *src, const ptrdiff_t src_stride,
                      const int dst_w, const int src_w, int h,
-                     const int dx, const int mx0)
+                     const int dx, const int mx0 HIGHBD_DECL_SUFFIX)
 {
     do {
         int mx = mx0, src_x = -1;
--- a/src/meson.build
+++ b/src/meson.build
@@ -52,9 +52,9 @@
 # These files are compiled for each bitdepth with
 # `BITDEPTH` defined to the currently built bitdepth.
 libdav1d_tmpl_sources = files(
+    'ipred_prepare_tmpl.c',
     'ipred_tmpl.c',
     'itx_tmpl.c',
-    'ipred_prepare_tmpl.c',
     'lf_apply_tmpl.c',
     'loopfilter_tmpl.c',
     'mc_tmpl.c',
--- a/src/recon_tmpl.c
+++ b/src/recon_tmpl.c
@@ -208,6 +208,9 @@
     const uint16_t *const dq_tbl = ts->dq[b->seg_id][plane];
     const uint8_t *const qm_tbl = f->qm[is_1d || *txtp == IDTX][tx][plane];
     const int dq_shift = imax(0, t_dim->ctx - 2);
+    const int bitdepth = BITDEPTH == 8 ? 8 : f->cur.p.bpc;
+    const int cf_min = -(1 << (7 + bitdepth));
+    const int cf_max = (1 << (7 + bitdepth)) - 1;
     for (int i = 0; i <= eob; i++) {
         const int rc = scan[i];
         int tok = cf[rc];
@@ -247,9 +250,7 @@
         // dequant, see 7.12.3
         cul_level += tok;
         tok = (((int64_t)dq * tok) & 0xffffff) >> dq_shift;
-        cf[rc] = iclip(sign ? -tok : tok,
-                       -(1 << (7 + BITDEPTH)),
-                       (1 << (7 + BITDEPTH)) - 1);
+        cf[rc] = iclip(sign ? -tok : tok, cf_min, cf_max); // signed (8 + bitdepth)-bit range
     }
 
     // context
@@ -349,7 +350,8 @@
             if (eob >= 0) {
                 if (DEBUG_BLOCK_INFO && DEBUG_B_PIXELS)
                     coef_dump(cf, imin(t_dim->h, 8) * 4, imin(t_dim->w, 8) * 4, 3, "dq");
-                dsp->itx.itxfm_add[ytx][txtp](dst, f->cur.stride[0], cf, eob);
+                dsp->itx.itxfm_add[ytx][txtp](dst, f->cur.stride[0], cf, eob
+                                              HIGHBD_CALL_SUFFIX); // ", f->bitdepth_max" in 16bpc builds
                 if (DEBUG_BLOCK_INFO && DEBUG_B_PIXELS)
                     hex_dump(dst, f->cur.stride[0], t_dim->w * 4, t_dim->h * 4, "recon");
             }
@@ -542,10 +544,12 @@
 
         if (dst8 != NULL) {
             f->dsp->mc.mc[filter_2d](dst8, dst_stride, ref, ref_stride, bw4 * h_mul,
-                                     bh4 * v_mul, mx << !ss_hor, my << !ss_ver);
+                                     bh4 * v_mul, mx << !ss_hor, my << !ss_ver
+                                     HIGHBD_CALL_SUFFIX);
         } else {
             f->dsp->mc.mct[filter_2d](dst16, ref, ref_stride, bw4 * h_mul,
-                                      bh4 * v_mul, mx << !ss_hor, my << !ss_ver);
+                                      bh4 * v_mul, mx << !ss_hor, my << !ss_ver
+                                      HIGHBD_CALL_SUFFIX);
         }
     } else {
         assert(refp != &f->sr_cur);
@@ -594,13 +598,15 @@
                                             bw4 * h_mul, bh4 * v_mul,
                                             pos_x & 0x3ff, pos_y & 0x3ff,
                                             f->svc[refidx][0].step,
-                                            f->svc[refidx][1].step);
+                                            f->svc[refidx][1].step
+                                            HIGHBD_CALL_SUFFIX);
         } else {
             f->dsp->mc.mct_scaled[filter_2d](dst16, ref, ref_stride,
                                              bw4 * h_mul, bh4 * v_mul,
                                              pos_x & 0x3ff, pos_y & 0x3ff,
                                              f->svc[refidx][0].step,
-                                             f->svc[refidx][1].step);
+                                             f->svc[refidx][1].step
+                                             HIGHBD_CALL_SUFFIX);
         }
     }
 
@@ -722,10 +728,10 @@
             }
             if (dst16 != NULL)
                 dsp->mc.warp8x8t(&dst16[x], dstride, ref_ptr, ref_stride,
-                                 wmp->abcd, mx, my);
+                                 wmp->abcd, mx, my HIGHBD_CALL_SUFFIX);
             else
                 dsp->mc.warp8x8(&dst8[x], dstride, ref_ptr, ref_stride,
-                                wmp->abcd, mx, my);
+                                wmp->abcd, mx, my HIGHBD_CALL_SUFFIX);
         }
         if (dst8) dst8  += 8 * PXSTRIDE(dstride);
         else      dst16 += 8 * dstride;
@@ -826,12 +832,14 @@
                                                           edge_flags, dst,
                                                           f->cur.stride[0], top_sb_edge,
                                                           b->y_mode, &angle,
-                                                          t_dim->w, t_dim->h, edge);
+                                                          t_dim->w, t_dim->h, edge
+                                                          HIGHBD_CALL_SUFFIX);
                     dsp->ipred.intra_pred[m](dst, f->cur.stride[0], edge,
                                              t_dim->w * 4, t_dim->h * 4,
                                              angle | intra_flags,
                                              4 * f->bw - 4 * t->bx,
-                                             4 * f->bh - 4 * t->by);
+                                             4 * f->bh - 4 * t->by
+                                             HIGHBD_CALL_SUFFIX);
 
                     if (DEBUG_BLOCK_INFO && DEBUG_B_PIXELS) {
                         hex_dump(edge - t_dim->h * 4, t_dim->h * 4,
@@ -882,7 +890,7 @@
                             dsp->itx.itxfm_add[b->tx]
                                               [txtp](dst,
                                                      f->cur.stride[0],
-                                                     cf, eob);
+                                                     cf, eob HIGHBD_CALL_SUFFIX);
                             if (DEBUG_BLOCK_INFO && DEBUG_B_PIXELS)
                                 hex_dump(dst, f->cur.stride[0],
                                          t_dim->w * 4, t_dim->h * 4, "recon");
@@ -943,11 +951,13 @@
                                                           0, uv_dst[pl], stride,
                                                           top_sb_edge, DC_PRED, &angle,
                                                           uv_t_dim->w,
-                                                          uv_t_dim->h, edge);
+                                                          uv_t_dim->h, edge
+                                                          HIGHBD_CALL_SUFFIX);
                     dsp->ipred.cfl_pred[m](uv_dst[pl], stride, edge,
                                            uv_t_dim->w * 4,
                                            uv_t_dim->h * 4,
-                                           ac, b->cfl_alpha[pl]);
+                                           ac, b->cfl_alpha[pl]
+                                           HIGHBD_CALL_SUFFIX);
                 }
                 if (DEBUG_BLOCK_INFO && DEBUG_B_PIXELS) {
                     ac_dump(ac, 4*cbw4, 4*cbh4, "ac");
@@ -1042,7 +1052,8 @@
                                                               edge_flags, dst, stride,
                                                               top_sb_edge, uv_mode,
                                                               &angle, uv_t_dim->w,
-                                                              uv_t_dim->h, edge);
+                                                              uv_t_dim->h, edge
+                                                              HIGHBD_CALL_SUFFIX);
                         angle |= intra_edge_filter_flag;
                         dsp->ipred.intra_pred[m](dst, stride, edge,
                                                  uv_t_dim->w * 4,
@@ -1051,7 +1062,8 @@
                                                  (4 * f->bw + ss_hor -
                                                   4 * (t->bx & ~ss_hor)) >> ss_hor,
                                                  (4 * f->bh + ss_ver -
-                                                  4 * (t->by & ~ss_ver)) >> ss_ver);
+                                                  4 * (t->by & ~ss_ver)) >> ss_ver
+                                                 HIGHBD_CALL_SUFFIX);
                         if (DEBUG_BLOCK_INFO && DEBUG_B_PIXELS) {
                             hex_dump(edge - uv_t_dim->h * 4, uv_t_dim->h * 4,
                                      uv_t_dim->h * 4, 2, "l");
@@ -1104,7 +1116,7 @@
                                               uv_t_dim->w * 4, 3, "dq");
                                 dsp->itx.itxfm_add[b->uvtx]
                                                   [txtp](dst, stride,
-                                                         cf, eob);
+                                                         cf, eob HIGHBD_CALL_SUFFIX);
                                 if (DEBUG_BLOCK_INFO && DEBUG_B_PIXELS)
                                     hex_dump(dst, stride, uv_t_dim->w * 4,
                                              uv_t_dim->h * 4, "recon");
@@ -1203,9 +1215,11 @@
                                                   t->by, t->by > ts->tiling.row_start,
                                                   ts->tiling.col_end, ts->tiling.row_end,
                                                   0, dst, f->cur.stride[0], top_sb_edge,
-                                                  m, &angle, bw4, bh4, tl_edge);
+                                                  m, &angle, bw4, bh4, tl_edge
+                                                  HIGHBD_CALL_SUFFIX);
             dsp->ipred.intra_pred[m](tmp, 4 * bw4 * sizeof(pixel),
-                                     tl_edge, bw4 * 4, bh4 * 4, 0, 0, 0);
+                                     tl_edge, bw4 * 4, bh4 * 4, 0, 0, 0
+                                     HIGHBD_CALL_SUFFIX);
             const uint8_t *const ii_mask =
                 b->interintra_type == INTER_INTRA_BLEND ?
                      dav1d_ii_masks[bs][0][b->interintra_mode] :
@@ -1343,9 +1357,11 @@
                                                           ts->tiling.row_end >> ss_ver,
                                                           0, uvdst, f->cur.stride[1],
                                                           top_sb_edge, m,
-                                                          &angle, cbw4, cbh4, tl_edge);
+                                                          &angle, cbw4, cbh4, tl_edge
+                                                          HIGHBD_CALL_SUFFIX);
                     dsp->ipred.intra_pred[m](tmp, cbw4 * 4 * sizeof(pixel),
-                                             tl_edge, cbw4 * 4, cbh4 * 4, 0, 0, 0);
+                                             tl_edge, cbw4 * 4, cbh4 * 4, 0, 0, 0
+                                             HIGHBD_CALL_SUFFIX);
                     dsp->mc.blend(uvdst, f->cur.stride[1], tmp,
                                   cbw4 * 4, cbh4 * 4, ii_mask);
                 }
@@ -1378,17 +1394,18 @@
         switch (b->comp_type) {
         case COMP_INTER_AVG:
             dsp->mc.avg(dst, f->cur.stride[0], tmp[0], tmp[1],
-                        bw4 * 4, bh4 * 4);
+                        bw4 * 4, bh4 * 4 HIGHBD_CALL_SUFFIX);
             break;
         case COMP_INTER_WEIGHTED_AVG:
             jnt_weight = f->jnt_weights[b->ref[0]][b->ref[1]];
             dsp->mc.w_avg(dst, f->cur.stride[0], tmp[0], tmp[1],
-                          bw4 * 4, bh4 * 4, jnt_weight);
+                          bw4 * 4, bh4 * 4, jnt_weight HIGHBD_CALL_SUFFIX);
             break;
         case COMP_INTER_SEG:
             dsp->mc.w_mask[chr_layout_idx](dst, f->cur.stride[0],
                                            tmp[b->mask_sign], tmp[!b->mask_sign],
-                                           bw4 * 4, bh4 * 4, seg_mask, b->mask_sign);
+                                           bw4 * 4, bh4 * 4, seg_mask,
+                                           b->mask_sign HIGHBD_CALL_SUFFIX);
             mask = seg_mask;
             break;
         case COMP_INTER_WEDGE:
@@ -1395,7 +1412,7 @@
             mask = dav1d_wedge_masks[bs][0][0][b->wedge_idx];
             dsp->mc.mask(dst, f->cur.stride[0],
                          tmp[b->mask_sign], tmp[!b->mask_sign],
-                         bw4 * 4, bh4 * 4, mask);
+                         bw4 * 4, bh4 * 4, mask HIGHBD_CALL_SUFFIX);
             if (has_chroma)
                 mask = dav1d_wedge_masks[bs][chr_layout_idx][b->mask_sign][b->wedge_idx];
             break;
@@ -1421,17 +1438,20 @@
             switch (b->comp_type) {
             case COMP_INTER_AVG:
                 dsp->mc.avg(uvdst, f->cur.stride[1], tmp[0], tmp[1],
-                            bw4 * 4 >> ss_hor, bh4 * 4 >> ss_ver);
+                            bw4 * 4 >> ss_hor, bh4 * 4 >> ss_ver
+                            HIGHBD_CALL_SUFFIX);
                 break;
             case COMP_INTER_WEIGHTED_AVG:
                 dsp->mc.w_avg(uvdst, f->cur.stride[1], tmp[0], tmp[1],
-                              bw4 * 4 >> ss_hor, bh4 * 4 >> ss_ver, jnt_weight);
+                              bw4 * 4 >> ss_hor, bh4 * 4 >> ss_ver, jnt_weight
+                              HIGHBD_CALL_SUFFIX);
                 break;
             case COMP_INTER_WEDGE:
             case COMP_INTER_SEG:
                 dsp->mc.mask(uvdst, f->cur.stride[1],
                              tmp[b->mask_sign], tmp[!b->mask_sign],
-                             bw4 * 4 >> ss_hor, bh4 * 4 >> ss_ver, mask);
+                             bw4 * 4 >> ss_hor, bh4 * 4 >> ss_ver, mask
+                             HIGHBD_CALL_SUFFIX);
                 break;
             }
         }
@@ -1546,7 +1566,7 @@
                             dsp->itx.itxfm_add[b->uvtx]
                                               [txtp](&uvdst[4 * x],
                                                      f->cur.stride[1],
-                                                     cf, eob);
+                                                     cf, eob HIGHBD_CALL_SUFFIX);
                             if (DEBUG_BLOCK_INFO && DEBUG_B_PIXELS)
                                 hex_dump(&uvdst[4 * x], f->cur.stride[1],
                                          uvtx->w * 4, uvtx->h * 4, "recon");
@@ -1613,7 +1633,7 @@
 
             f->dsp->mc.resize(dst, dst_stride, src, src_stride, dst_w, src_w,
                               imin(img_h, h_end) + h_start, f->resize_step[!!pl],
-                              f->resize_start[!!pl]);
+                              f->resize_start[!!pl] HIGHBD_CALL_SUFFIX);
         }
     }
     if (f->seq_hdr->restoration) {
--- a/tests/checkasm/cdef.c
+++ b/tests/checkasm/cdef.c
@@ -32,9 +32,9 @@
 #include "src/levels.h"
 #include "src/cdef.h"
 
-static void init_tmp(pixel *buf, int n) {
+static void init_tmp(pixel *buf, int n, const int bitdepth_max) {
     while (n--)
-        *buf++ = rand() & ((1 << BITDEPTH) - 1);
+        *buf++ = rand() & bitdepth_max; // callers pass 2^bpc - 1 (0xff, 0x3ff or 0xfff)
 }
 
 static void check_cdef_filter(const cdef_fn fn, const int w, const int h,
@@ -48,12 +48,8 @@
 
     declare_func(void, pixel *dst, ptrdiff_t dst_stride, const pixel (*left)[2],
                  pixel *const top[2], int pri_strength, int sec_strength,
-                 int dir, int damping, enum CdefEdgeFlags edges);
+                 int dir, int damping, enum CdefEdgeFlags edges HIGHBD_DECL_SUFFIX);
 
-    init_tmp(src, 10 * 16 + 8);
-    init_tmp(top, 16 * 2 + 8);
-    init_tmp((pixel *) left,8 * 2);
-
     if (check_func(fn, "%s_%dbpc", name, BITDEPTH)) {
         for (int dir = 0; dir < 8; dir++) {
             for (enum CdefEdgeFlags edges = 0; edges <= 0xf; edges++) {
@@ -60,21 +56,35 @@
-                memcpy(a_src, src, (10 * 16 + 8) * sizeof(pixel));
-                memcpy(c_src, src, (10 * 16 + 8) * sizeof(pixel));
-
+#if BITDEPTH == 16
+                const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff;
+#else
+                const int bitdepth_max = 0xff;
+#endif
+                const int bitdepth_min_8 = bitdepth_from_max(bitdepth_max) - 8;
+                init_tmp(src, 10 * 16 + 8, bitdepth_max);
+                init_tmp(top, 16 * 2 + 8, bitdepth_max);
+                init_tmp((pixel *) left, 8 * 2, bitdepth_max);
+                // copy only after src was refilled for this bitdepth_max
+                memcpy(a_src, src, (10 * 16 + 8) * sizeof(pixel));
+                memcpy(c_src, src, (10 * 16 + 8) * sizeof(pixel));
+
                 const int lvl = 1 + (rand() % 62);
-                const int damping = 3 + (rand() & 3);
-                const int pri_strength = (lvl >> 2) << (BITDEPTH - 8);
+                const int damping = 3 + (rand() & 3) + bitdepth_min_8;
+                const int pri_strength = (lvl >> 2) << bitdepth_min_8;
                 int sec_strength = lvl & 3;
                 sec_strength += sec_strength == 3;
+                sec_strength <<= bitdepth_min_8;
                 call_ref(c_src_ptr, 16 * sizeof(pixel), left,
                          (pixel *[2]) { top_ptr, top_ptr + 16 },
-                         pri_strength, sec_strength, dir, damping, edges);
+                         pri_strength, sec_strength, dir, damping, edges
+                         HIGHBD_TAIL_SUFFIX);
                 call_new(a_src_ptr, 16 * sizeof(pixel), left,
                          (pixel *[2]) { top_ptr, top_ptr + 16 },
-                         pri_strength, sec_strength, dir, damping, edges);
+                         pri_strength, sec_strength, dir, damping, edges
+                         HIGHBD_TAIL_SUFFIX);
                 if (memcmp(a_src, c_src, (10 * 16 + 8) * sizeof(pixel))) fail();
                 bench_new(a_src_ptr, 16 * sizeof(pixel), left,
                           (pixel *[2]) { top_ptr, top_ptr + 16 },
-                          pri_strength, sec_strength, dir, damping, edges);
+                          pri_strength, sec_strength, dir, damping, edges
+                          HIGHBD_TAIL_SUFFIX);
             }
         }
     }
@@ -84,17 +94,22 @@
 static void check_cdef_direction(const cdef_dir_fn fn) {
     ALIGN_STK_32(pixel, src, 8 * 8,);
 
-    declare_func(int, pixel *src, ptrdiff_t dst_stride, unsigned *var);
+    declare_func(int, pixel *src, ptrdiff_t dst_stride, unsigned *var
+                 HIGHBD_DECL_SUFFIX);
 
-    init_tmp(src, 64);
-
     if (check_func(fn, "cdef_dir_%dbpc", BITDEPTH)) {
         unsigned c_var, a_var;
+#if BITDEPTH == 16
+        const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff; // 10-bit or 12-bit max
+#else
+        const int bitdepth_max = 0xff;
+#endif
+        init_tmp(src, 64, bitdepth_max);
 
-        const int c_dir = call_ref(src, 8 * sizeof(pixel), &c_var);
-        const int a_dir = call_new(src, 8 * sizeof(pixel), &a_var);
+        const int c_dir = call_ref(src, 8 * sizeof(pixel), &c_var HIGHBD_TAIL_SUFFIX);
+        const int a_dir = call_new(src, 8 * sizeof(pixel), &a_var HIGHBD_TAIL_SUFFIX);
         if (c_var != a_var || c_dir != a_dir) fail();
-        bench_new(src, 8 * sizeof(pixel), &a_var);
+        bench_new(src, 8 * sizeof(pixel), &a_var HIGHBD_TAIL_SUFFIX);
     }
     report("cdef_dir");
 }
--- a/tests/checkasm/checkasm.c
+++ b/tests/checkasm/checkasm.c
@@ -69,13 +69,13 @@
     { "looprestoration_8bpc", checkasm_check_looprestoration_8bpc },
     { "mc_8bpc", checkasm_check_mc_8bpc },
 #endif
-#if CONFIG_10BPC
-    { "cdef_10bpc", checkasm_check_cdef_10bpc },
-    { "ipred_10bpc", checkasm_check_ipred_10bpc },
-    { "itx_10bpc", checkasm_check_itx_10bpc },
-    { "loopfilter_10bpc", checkasm_check_loopfilter_10bpc },
-    { "looprestoration_10bpc", checkasm_check_looprestoration_10bpc },
-    { "mc_10bpc", checkasm_check_mc_10bpc },
+#if CONFIG_16BPC /* 16bpc checks pick 10- or 12-bit bitdepth_max at runtime */
+    { "cdef_16bpc", checkasm_check_cdef_16bpc },
+    { "ipred_16bpc", checkasm_check_ipred_16bpc },
+    { "itx_16bpc", checkasm_check_itx_16bpc },
+    { "loopfilter_16bpc", checkasm_check_loopfilter_16bpc },
+    { "looprestoration_16bpc", checkasm_check_looprestoration_16bpc },
+    { "mc_16bpc", checkasm_check_mc_16bpc },
 #endif
     { 0 }
 };
--- a/tests/checkasm/checkasm.h
+++ b/tests/checkasm/checkasm.h
@@ -36,23 +36,16 @@
 #include "include/common/attributes.h"
 #include "include/common/intops.h"
 
-void checkasm_check_cdef_8bpc(void);
-void checkasm_check_cdef_10bpc(void);
+#define decl_check_bitfns(name) \
+name##_8bpc(void); \
+name##_16bpc(void) /* "name" also carries the return type, e.g. void checkasm_check_cdef */
 
-void checkasm_check_ipred_8bpc(void);
-void checkasm_check_ipred_10bpc(void);
-
-void checkasm_check_itx_8bpc(void);
-void checkasm_check_itx_10bpc(void);
-
-void checkasm_check_loopfilter_8bpc(void);
-void checkasm_check_loopfilter_10bpc(void);
-
-void checkasm_check_looprestoration_8bpc(void);
-void checkasm_check_looprestoration_10bpc(void);
-
-void checkasm_check_mc_8bpc(void);
-void checkasm_check_mc_10bpc(void);
+decl_check_bitfns(void checkasm_check_cdef);
+decl_check_bitfns(void checkasm_check_ipred);
+decl_check_bitfns(void checkasm_check_itx);
+decl_check_bitfns(void checkasm_check_loopfilter);
+decl_check_bitfns(void checkasm_check_looprestoration);
+decl_check_bitfns(void checkasm_check_mc);
 
 void *checkasm_check_func(void *func, const char *name, ...);
 int checkasm_bench_func(void);
--- a/tests/checkasm/ipred.c
+++ b/tests/checkasm/ipred.c
@@ -70,7 +70,8 @@
     pixel *const topleft = topleft_buf + 128;
 
     declare_func(void, pixel *dst, ptrdiff_t stride, const pixel *topleft,
-                 int width, int height, int angle, int max_width, int max_height);
+                 int width, int height, int angle, int max_width, int max_height
+                 HIGHBD_DECL_SUFFIX);
 
     for (int mode = 0; mode < N_IMPL_INTRA_PRED_MODES; mode++)
         for (int w = 4; w <= (mode == FILTER_PRED ? 32 : 64); w <<= 1)
@@ -89,16 +90,25 @@
                     else if (mode == FILTER_PRED) /* filter_idx */
                         a = (rand() % 5) | (rand() & ~511);
 
+#if BITDEPTH == 16
+                    const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff; // 10-bit or 12-bit max
+#else
+                    const int bitdepth_max = 0xff;
+#endif
+
                     for (int i = -h * 2; i <= w * 2; i++)
-                        topleft[i] = rand() & ((1 << BITDEPTH) - 1);
+                        topleft[i] = rand() & bitdepth_max;
 
                     const int maxw = 1 + (rand() % 128), maxh = 1 + (rand() % 128);
-                    call_ref(c_dst, stride, topleft, w, h, a, maxw, maxh);
-                    call_new(a_dst, stride, topleft, w, h, a, maxw, maxh);
+                    call_ref(c_dst, stride, topleft, w, h, a, maxw, maxh
+                             HIGHBD_TAIL_SUFFIX);
+                    call_new(a_dst, stride, topleft, w, h, a, maxw, maxh
+                             HIGHBD_TAIL_SUFFIX);
                     if (memcmp(c_dst, a_dst, w * h * sizeof(*c_dst)))
                         fail();
 
-                    bench_new(a_dst, stride, topleft, w, h, a, 128, 128);
+                    bench_new(a_dst, stride, topleft, w, h, a, 128, 128
+                              HIGHBD_TAIL_SUFFIX);
                 }
             }
     report("intra_pred");
@@ -123,9 +133,14 @@
                     const ptrdiff_t stride = 32 * sizeof(pixel);
                     for (int w_pad = (w >> 2) - 1; w_pad >= 0; w_pad--) {
                         for (int h_pad = (h >> 2) - 1; h_pad >= 0; h_pad--) {
+#if BITDEPTH == 16
+                            const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff;
+#else
+                            const int bitdepth_max = 0xff;
+#endif
                             for (int y = 0; y < (h << ss_ver); y++)
                                 for (int x = 0; x < (w << ss_hor); x++)
-                                    luma[y * 32 + x] = rand() & ((1 << BITDEPTH) - 1);
+                                    luma[y * 32 + x] = rand() & bitdepth_max;
 
                             call_ref(c_dst, luma, stride, w_pad, h_pad, w, h);
                             call_new(a_dst, luma, stride, w_pad, h_pad, w, h);
@@ -149,7 +164,8 @@
     pixel *const topleft = topleft_buf + 128;
 
     declare_func(void, pixel *dst, ptrdiff_t stride, const pixel *topleft,
-                 int width, int height, const int16_t *ac, int alpha);
+                 int width, int height, const int16_t *ac, int alpha
+                 HIGHBD_DECL_SUFFIX);
 
     for (int mode = 0; mode <= DC_128_PRED; mode += 1 + 2 * !mode)
         for (int w = 4; w <= 32; w <<= 1)
@@ -158,26 +174,35 @@
             {
                 for (int h = imax(w / 4, 4); h <= imin(w * 4, 32); h <<= 1)
                 {
+#if BITDEPTH == 16
+                    const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff;
+#else
+                    const int bitdepth_max = 0xff;
+#endif
+
                     const ptrdiff_t stride = w * sizeof(pixel);
 
                     int alpha = ((rand() & 15) + 1) * (1 - (rand() & 2));
 
                     for (int i = -h * 2; i <= w * 2; i++)
-                        topleft[i] = rand() & ((1 << BITDEPTH) - 1);
+                        topleft[i] = rand() & bitdepth_max;
 
                     int luma_avg = w * h >> 1;
                     for (int i = 0; i < w * h; i++)
-                        luma_avg += ac[i] = rand() & ((1 << BITDEPTH) - 1) << 3;
+                        luma_avg += ac[i] = rand() & (bitdepth_max << 3); // same mask as the old (max << 3)
                     luma_avg /= w * h;
                     for (int i = 0; i < w * h; i++)
                         ac[i] -= luma_avg;
 
-                    call_ref(c_dst, stride, topleft, w, h, ac, alpha);
-                    call_new(a_dst, stride, topleft, w, h, ac, alpha);
+                    call_ref(c_dst, stride, topleft, w, h, ac, alpha
+                             HIGHBD_TAIL_SUFFIX);
+                    call_new(a_dst, stride, topleft, w, h, ac, alpha
+                             HIGHBD_TAIL_SUFFIX);
                     if (memcmp(c_dst, a_dst, w * h * sizeof(*c_dst)))
                         fail();
 
-                    bench_new(a_dst, stride, topleft, w, h, ac, alpha);
+                    bench_new(a_dst, stride, topleft, w, h, ac, alpha
+                              HIGHBD_TAIL_SUFFIX);
                 }
             }
     report("cfl_pred");
@@ -196,10 +221,15 @@
         if (check_func(c->pal_pred, "pal_pred_w%d_%dbpc", w, BITDEPTH))
             for (int h = imax(w / 4, 4); h <= imin(w * 4, 64); h <<= 1)
             {
+#if BITDEPTH == 16
+                const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff;
+#else
+                const int bitdepth_max = 0xff;
+#endif
                 const ptrdiff_t stride = w * sizeof(pixel);
 
                 for (int i = 0; i < 8; i++)
-                    pal[i] = rand() & ((1 << BITDEPTH) - 1);
+                    pal[i] = rand() & bitdepth_max;
 
                 for (int i = 0; i < w * h; i++)
                     idx[i] = rand() & 7;
--- a/tests/checkasm/itx.c
+++ b/tests/checkasm/itx.c
@@ -163,7 +163,7 @@
 
 static int ftx(coef *const buf, const enum RectTxfmSize tx,
                const enum TxfmType txtp, const int w, const int h,
-               const int subsh)
+               const int subsh, const int bitdepth_max)
 {
     double out[64 * 64], temp[64 * 64];
     const double scale = scaling_factors[ctz(w * h) - 4];
@@ -173,7 +173,7 @@
         double in[64], temp_out[64];
 
         for (int i = 0; i < w; i++)
-            in[i] = (rand() & ((2 << BITDEPTH) - 1)) - ((1 << BITDEPTH) - 1);
+            in[i] = (rand() & (2 * bitdepth_max + 1)) - bitdepth_max; // in [-max, max + 1]
 
         switch (itx_1d_types[txtp][0]) {
         case DCT:
@@ -238,7 +238,8 @@
 
     static const uint8_t subsh_iters[5] = { 2, 2, 3, 5, 5 };
 
-    declare_func(void, pixel *dst, ptrdiff_t dst_stride, coef *coeff, int eob);
+    declare_func(void, pixel *dst, ptrdiff_t dst_stride, coef *coeff, int eob
+                 HIGHBD_DECL_SUFFIX);
 
     for (int i = 0; i < N_RECT_TX_SIZES; i++) {
         const enum RectTxfmSize tx = txfm_size_order[i];
@@ -256,16 +257,23 @@
                                itx_1d_names[itx_1d_types[txtp][1]], subsh,
                                BITDEPTH))
                 {
-                    const int eob = ftx(coeff[0], tx, txtp, w, h, subsh);
+#if BITDEPTH == 16
+                    const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff;
+#else
+                    const int bitdepth_max = 0xff;
+#endif
+                    const int eob = ftx(coeff[0], tx, txtp, w, h, subsh, bitdepth_max);
 
                     for (int j = 0; j < w * h; j++)
-                        c_dst[j] = a_dst[j] = rand() & ((1 << BITDEPTH) - 1);
+                        c_dst[j] = a_dst[j] = rand() & bitdepth_max;
 
                     memcpy(coeff[1], coeff[0], sw * sh * sizeof(**coeff));
                     memcpy(coeff[2], coeff[0], sw * sh * sizeof(**coeff));
 
-                    call_ref(c_dst, w * sizeof(*c_dst), coeff[0], eob);
-                    call_new(a_dst, w * sizeof(*c_dst), coeff[1], eob);
+                    call_ref(c_dst, w * sizeof(*c_dst), coeff[0], eob
+                             HIGHBD_TAIL_SUFFIX);
+                    call_new(a_dst, w * sizeof(*c_dst), coeff[1], eob
+                             HIGHBD_TAIL_SUFFIX);
                     if (memcmp(c_dst, a_dst, w * h * sizeof(*c_dst)) ||
                         memcmp(coeff[0], coeff[1], sw * sh * sizeof(**coeff)))
                     {
@@ -272,7 +280,8 @@
                         fail();
                     }
 
-                    bench_new(a_dst, w * sizeof(*c_dst), coeff[2], eob);
+                    bench_new(a_dst, w * sizeof(*c_dst), coeff[2], eob
+                              HIGHBD_TAIL_SUFFIX);
                 }
         report("add_%dx%d", w, h);
     }
--- a/tests/checkasm/loopfilter.c
+++ b/tests/checkasm/loopfilter.c
@@ -33,12 +33,13 @@
 #include "src/loopfilter.h"
 
 static void init_lpf_border(pixel *const dst, const ptrdiff_t stride,
-                            int E, int I, int H)
+                            int E, int I, int H, const int bitdepth_max) // name is read by iclip_pixel() in 16bpc
 {
-    const int F = 1 << (BITDEPTH - 8);
-    E <<= BITDEPTH - 8;
-    I <<= BITDEPTH - 8;
-    H <<= BITDEPTH - 8;
+    const int bitdepth_min_8 = bitdepth_from_max(bitdepth_max) - 8; // 0, 2 or 4 for 0xff/0x3ff/0xfff
+    const int F = 1 << bitdepth_min_8;
+    E <<= bitdepth_min_8;
+    I <<= bitdepth_min_8;
+    H <<= bitdepth_min_8;
 
     const int filter_type = rand() % 4;
     const int edge_diff = rand() % ((E + 2) * 4) - 2 * (E + 2);
@@ -45,12 +46,12 @@
     switch (filter_type) {
     case 0: // random, unfiltered
         for (int i = -8; i < 8; i++)
-            dst[i * stride] = rand() & ((1 << BITDEPTH) - 1);
+            dst[i * stride] = rand() & bitdepth_max;
         break;
     case 1: // long flat
-        dst[-8 * stride] = rand() & ((1 << BITDEPTH) - 1);
-        dst[+7 * stride] = rand() & ((1 << BITDEPTH) - 1);
-        dst[+0 * stride] = rand() & ((1 << BITDEPTH) - 1);
+        dst[-8 * stride] = rand() & bitdepth_max;
+        dst[+7 * stride] = rand() & bitdepth_max;
+        dst[+0 * stride] = rand() & bitdepth_max;
         dst[-1 * stride] = iclip_pixel(dst[+0 * stride] + edge_diff);
         for (int i = 1; i < 7; i++) {
             dst[-(1 + i) * stride] = iclip_pixel(dst[-1 * stride] +
@@ -61,10 +62,10 @@
         break;
     case 2: // short flat
         for (int i = 4; i < 8; i++) {
-            dst[-(1 + i) * stride] = rand() & ((1 << BITDEPTH) - 1);
-            dst[+(0 + i) * stride] = rand() & ((1 << BITDEPTH) - 1);
+            dst[-(1 + i) * stride] = rand() & bitdepth_max;
+            dst[+(0 + i) * stride] = rand() & bitdepth_max;
         }
-        dst[+0 * stride] = rand() & ((1 << BITDEPTH) - 1);
+        dst[+0 * stride] = rand() & bitdepth_max;
         dst[-1 * stride] = iclip_pixel(dst[+0 * stride] + edge_diff);
         for (int i = 1; i < 4; i++) {
             dst[-(1 + i) * stride] = iclip_pixel(dst[-1 * stride] +
@@ -75,10 +76,10 @@
         break;
     case 3: // normal or hev
         for (int i = 4; i < 8; i++) {
-            dst[-(1 + i) * stride] = rand() & ((1 << BITDEPTH) - 1);
-            dst[+(0 + i) * stride] = rand() & ((1 << BITDEPTH) - 1);
+            dst[-(1 + i) * stride] = rand() & bitdepth_max;
+            dst[+(0 + i) * stride] = rand() & bitdepth_max;
         }
-        dst[+0 * stride] = rand() & ((1 << BITDEPTH) - 1);
+        dst[+0 * stride] = rand() & bitdepth_max;
         dst[-1 * stride] = iclip_pixel(dst[+0 * stride] + edge_diff);
         for (int i = 1; i < 4; i++) {
             dst[-(1 + i) * stride] = iclip_pixel(dst[-(0 + i) * stride] +
@@ -112,7 +113,7 @@
 
     declare_func(void, pixel *dst, ptrdiff_t dst_stride, const uint32_t *mask,
                  const uint8_t (*l)[4], ptrdiff_t b4_stride,
-                 const Av1FilterLUT *lut, int w);
+                 const Av1FilterLUT *lut, int w HIGHBD_DECL_SUFFIX);
 
     Av1FilterLUT lut;
     const int sharp = rand() & 7;
@@ -150,6 +151,11 @@
                     l[j * 2 + 1][lf_idx] = rand() & 63;
                 }
             }
+#if BITDEPTH == 16
+            const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff;
+#else
+            const int bitdepth_max = 0xff;
+#endif
 
             for (int i = 0; i < 4 * n_blks; i++) {
                 const int x = i >> 2;
@@ -160,21 +166,21 @@
                     L = l[2 * x + 1][lf_idx] ? l[2 * x + 1][lf_idx] : l[2 * x][lf_idx];
                 }
                 init_lpf_border(c_dst + i * (dir ? 1 : 16), dir ? 128 : 1,
-                                lut.e[L], lut.i[L], L >> 4);
+                                lut.e[L], lut.i[L], L >> 4, bitdepth_max);
             }
             memcpy(a_dst_mem, c_dst_mem, 128 * sizeof(pixel) * 16);
 
             call_ref(c_dst, stride,
                      vmask, (const uint8_t(*)[4]) &l[dir ? 32 : 1][lf_idx], b4_stride,
-                     &lut, n_blks);
+                     &lut, n_blks HIGHBD_TAIL_SUFFIX);
             call_new(a_dst, stride,
                      vmask, (const uint8_t(*)[4]) &l[dir ? 32 : 1][lf_idx], b4_stride,
-                     &lut, n_blks);
+                     &lut, n_blks HIGHBD_TAIL_SUFFIX);
             if (memcmp(c_dst_mem, a_dst_mem, 128 * 16 * sizeof(*a_dst)))  fail();
 
             bench_new(a_dst, stride,
                       vmask, (const uint8_t(*)[4]) &l[dir ? 32 : 1][lf_idx], b4_stride,
-                      &lut, n_blks);
+                      &lut, n_blks HIGHBD_TAIL_SUFFIX);
         }
     }
     report(name);
--- a/tests/checkasm/looprestoration.c
+++ b/tests/checkasm/looprestoration.c
@@ -34,11 +34,11 @@
 #include "src/tables.h"
 
 static void init_tmp(pixel *buf, const ptrdiff_t stride,
-                     const int w, const int h)
+                     const int w, const int h, const int bitdepth_max)
 {
     for (int y = 0; y < h; y++) {
         for (int x = 0; x < w; x++)
-            buf[x] = rand() & ((1 << BITDEPTH) - 1);
+            buf[x] = rand() & bitdepth_max; // callers pass 2^bpc - 1
         buf += PXSTRIDE(stride);
     }
 }
@@ -65,12 +65,9 @@
                  const pixel (*const left)[4],
                  const pixel *lpf, ptrdiff_t lpf_stride,
                  int w, int h, const int16_t filterh[7],
-                 const int16_t filterv[7], enum LrEdgeFlags edges);
+                 const int16_t filterv[7], enum LrEdgeFlags edges
+                 HIGHBD_DECL_SUFFIX);
 
-    init_tmp(c_dst, 448 * sizeof(pixel), 448, 64);
-    init_tmp(h_edge, 448 * sizeof(pixel), 448, 8);
-    init_tmp((pixel *) left, 4 * sizeof(pixel), 4, 64);
-
     for (int pl = 0; pl < 2; pl++) {
         if (check_func(c->wiener, "wiener_%s_%dbpc",
                        pl ? "chroma" : "luma", BITDEPTH))
@@ -96,6 +93,16 @@
 
             const int base_w = 1 + (rand() % 384);
             const int base_h = 1 + (rand() & 63);
+#if BITDEPTH == 16
+            const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff; // 10-bit or 12-bit max
+#else
+            const int bitdepth_max = 0xff;
+#endif
+
+            init_tmp(c_dst, 448 * sizeof(pixel), 448, 64, bitdepth_max);
+            init_tmp(h_edge, 448 * sizeof(pixel), 448, 8, bitdepth_max);
+            init_tmp((pixel *) left, 4 * sizeof(pixel), 4, 64, bitdepth_max);
+
             for (enum LrEdgeFlags edges = 0; edges <= 0xf; edges++) {
                 const int w = edges & LR_HAVE_RIGHT ? 256 : base_w;
                 const int h = edges & LR_HAVE_BOTTOM ? 64 : base_h;
@@ -104,16 +111,16 @@
 
                 call_ref(c_dst + 32, 448 * sizeof(pixel), left,
                          h_edge + 32, 448 * sizeof(pixel),
-                         w, h, filter_h, filter_v, edges);
+                         w, h, filter_h, filter_v, edges HIGHBD_TAIL_SUFFIX);
                 call_new(a_dst + 32, 448 * sizeof(pixel), left,
                          h_edge + 32, 448 * sizeof(pixel),
-                         w, h, filter_h, filter_v, edges);
+                         w, h, filter_h, filter_v, edges HIGHBD_TAIL_SUFFIX);
                 const int res = cmp2d(c_dst + 32, a_dst + 32, 448 * sizeof(pixel), w, h);
                 if (res != -1) fail();
             }
             bench_new(a_dst + 32, 448 * sizeof(pixel), left,
                       h_edge + 32, 448 * sizeof(pixel),
-                      256, 64, filter_h, filter_v, 0xf);
+                      256, 64, filter_h, filter_v, 0xf HIGHBD_TAIL_SUFFIX);
         }
     }
     report("wiener");
@@ -129,12 +136,9 @@
                  const pixel (*const left)[4],
                  const pixel *lpf, ptrdiff_t lpf_stride,
                  int w, int h, int sgr_idx,
-                 const int16_t sgr_wt[7], enum LrEdgeFlags edges);
+                 const int16_t sgr_wt[7], enum LrEdgeFlags edges
+                 HIGHBD_DECL_SUFFIX);
 
-    init_tmp(c_dst, 448 * sizeof(pixel), 448, 64);
-    init_tmp(h_edge, 448 * sizeof(pixel), 448, 8);
-    init_tmp((pixel *) left, 4 * sizeof(pixel), 4, 64);
-
     for (int sgr_idx = 14; sgr_idx >= 6; sgr_idx -= 4) {
         if (check_func(c->selfguided, "selfguided_%s_%dbpc",
                        sgr_idx == 6 ? "mix" : sgr_idx == 10 ? "3x3" : "5x5", BITDEPTH))
@@ -147,6 +151,16 @@
 
             const int base_w = 1 + (rand() % 384);
             const int base_h = 1 + (rand() & 63);
+#if BITDEPTH == 16
+            const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff;
+#else
+            const int bitdepth_max = 0xff;
+#endif
+
+            init_tmp(c_dst, 448 * sizeof(pixel), 448, 64, bitdepth_max);
+            init_tmp(h_edge, 448 * sizeof(pixel), 448, 8, bitdepth_max);
+            init_tmp((pixel *) left, 4 * sizeof(pixel), 4, 64, bitdepth_max);
+
             for (enum LrEdgeFlags edges = 0; edges <= 0xf; edges++) {
                 const int w = edges & LR_HAVE_RIGHT ? 256 : base_w;
                 const int h = edges & LR_HAVE_BOTTOM ? 64 : base_h;
@@ -155,16 +169,16 @@
 
                 call_ref(c_dst + 32, 448 * sizeof(pixel), left,
                          h_edge + 32, 448 * sizeof(pixel),
-                         w, h, sgr_idx, sgr_wt, edges);
+                         w, h, sgr_idx, sgr_wt, edges HIGHBD_TAIL_SUFFIX);
                 call_new(a_dst + 32, 448 * sizeof(pixel), left,
                          h_edge + 32, 448 * sizeof(pixel),
-                         w, h, sgr_idx, sgr_wt, edges);
+                         w, h, sgr_idx, sgr_wt, edges HIGHBD_TAIL_SUFFIX);
                 const int res = cmp2d(c_dst + 32, a_dst + 32, 448 * sizeof(pixel), w, h);
                 if (res != -1) fail();
             }
             bench_new(a_dst + 32, 448 * sizeof(pixel), left,
                       h_edge + 32, 448 * sizeof(pixel),
-                      256, 64, sgr_idx, sgr_wt, 0xf);
+                      256, 64, sgr_idx, sgr_wt, 0xf HIGHBD_TAIL_SUFFIX);
         }
     }
     report("sgr");
--- a/tests/checkasm/mc.c
+++ b/tests/checkasm/mc.c
@@ -47,11 +47,9 @@
     ALIGN_STK_32(pixel, a_dst,   128 * 128,);
     const pixel *src = src_buf + 135 * 3 + 3;
 
-    for (int i = 0; i < 135 * 135; i++)
-        src_buf[i] = rand();
-
     declare_func(void, pixel *dst, ptrdiff_t dst_stride, const pixel *src,
-                 ptrdiff_t src_stride, int w, int h, int mx, int my);
+                 ptrdiff_t src_stride, int w, int h, int mx, int my
+                 HIGHBD_DECL_SUFFIX);
 
     for (int filter = 0; filter < N_2D_FILTERS; filter++)
         for (int w = 2; w <= 128; w <<= 1)
@@ -64,15 +62,23 @@
                     for (int h = min; h <= max; h <<= 1) {
                         const int mx = (mxy & 1) ? rand() % 15 + 1 : 0;
                         const int my = (mxy & 2) ? rand() % 15 + 1 : 0;
+#if BITDEPTH == 16
+                        const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff;
+#else
+                        const int bitdepth_max = 0xff;
+#endif
 
-                        call_ref(c_dst, w, src, w, w, h, mx, my);
-                        call_new(a_dst, w, src, w, w, h, mx, my);
+                        for (int i = 0; i < 135 * 135; i++)
+                            src_buf[i] = rand() & bitdepth_max;
+
+                        call_ref(c_dst, w, src, w, w, h, mx, my HIGHBD_TAIL_SUFFIX);
+                        call_new(a_dst, w, src, w, w, h, mx, my HIGHBD_TAIL_SUFFIX);
                         if (memcmp(c_dst, a_dst, w * h * sizeof(*c_dst)))
                             fail();
 
                         if (filter == FILTER_2D_8TAP_REGULAR ||
                             filter == FILTER_2D_BILINEAR)
-                            bench_new(a_dst, w, src, w, w, h, mx, my);
+                            bench_new(a_dst, w, src, w, w, h, mx, my HIGHBD_TAIL_SUFFIX);
                     }
                 }
     report("mc");
@@ -84,11 +90,8 @@
     ALIGN_STK_32(int16_t, a_tmp,   128 * 128,);
     const pixel *src = src_buf + 135 * 3 + 3;
 
-    for (int i = 0; i < 135 * 135; i++)
-        src_buf[i] = rand();
-
     declare_func(void, int16_t *tmp, const pixel *src, ptrdiff_t src_stride,
-                 int w, int h, int mx, int my);
+                 int w, int h, int mx, int my HIGHBD_DECL_SUFFIX);
 
     for (int filter = 0; filter < N_2D_FILTERS; filter++)
         for (int w = 4; w <= 128; w <<= 1)
@@ -99,28 +102,37 @@
                     {
                         const int mx = (mxy & 1) ? rand() % 15 + 1 : 0;
                         const int my = (mxy & 2) ? rand() % 15 + 1 : 0;
+#if BITDEPTH == 16
+                        const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff;
+#else
+                        const int bitdepth_max = 0xff;
+#endif
 
-                        call_ref(c_tmp, src, w, w, h, mx, my);
-                        call_new(a_tmp, src, w, w, h, mx, my);
+                        for (int i = 0; i < 135 * 135; i++)
+                            src_buf[i] = rand() & bitdepth_max;
+
+                        call_ref(c_tmp, src, w, w, h, mx, my HIGHBD_TAIL_SUFFIX);
+                        call_new(a_tmp, src, w, w, h, mx, my HIGHBD_TAIL_SUFFIX);
                         if (memcmp(c_tmp, a_tmp, w * h * sizeof(*c_tmp)))
                             fail();
 
                         if (filter == FILTER_2D_8TAP_REGULAR ||
                             filter == FILTER_2D_BILINEAR)
-                            bench_new(a_tmp, src, w, w, h, mx, my);
+                            bench_new(a_tmp, src, w, w, h, mx, my HIGHBD_TAIL_SUFFIX);
                     }
     report("mct");
 }
 
 static void init_tmp(Dav1dMCDSPContext *const c, pixel *const buf,
-                     int16_t (*const tmp)[128 * 128])
+                     int16_t (*const tmp)[128 * 128], const int bitdepth_max) /* fill tmp[0], tmp[1] via random mct calls; buf (135*135 pixels) is scratch input */
 {
     for (int i = 0; i < 2; i++) {
         for (int j = 0; j < 135 * 135; j++)
-            buf[j] = rand();
+            buf[j] = rand() & bitdepth_max; /* keep source pixels within the randomly chosen bitdepth range */
         c->mct[rand() % N_2D_FILTERS](tmp[i], buf + 135 * 3 + 3,
                                       128 * sizeof(pixel), 128, 128,
-                                      rand() & 15, rand() & 15);
+                                      rand() & 15, rand() & 15
+                                      HIGHBD_TAIL_SUFFIX); /* ", bitdepth_max" in 16bpc builds (needs this param name), empty in 8bpc */
     }
 }
 
@@ -129,21 +141,25 @@
     ALIGN_STK_32(pixel, c_dst, 135 * 135,);
     ALIGN_STK_32(pixel, a_dst, 128 * 128,);
 
-    init_tmp(c, c_dst, tmp);
-
     declare_func(void, pixel *dst, ptrdiff_t dst_stride, const int16_t *tmp1,
-                 const int16_t *tmp2, int w, int h);
+                 const int16_t *tmp2, int w, int h HIGHBD_DECL_SUFFIX);
 
     for (int w = 4; w <= 128; w <<= 1)
         if (check_func(c->avg, "avg_w%d_%dbpc", w, BITDEPTH))
             for (int h = imax(w / 4, 4); h <= imin(w * 4, 128); h <<= 1)
             {
-                call_ref(c_dst, w, tmp[0], tmp[1], w, h);
-                call_new(a_dst, w, tmp[0], tmp[1], w, h);
+#if BITDEPTH == 16
+                const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff;
+#else
+                const int bitdepth_max = 0xff;
+#endif
+                init_tmp(c, c_dst, tmp, bitdepth_max);
+                call_ref(c_dst, w, tmp[0], tmp[1], w, h HIGHBD_TAIL_SUFFIX);
+                call_new(a_dst, w, tmp[0], tmp[1], w, h HIGHBD_TAIL_SUFFIX);
                 if (memcmp(c_dst, a_dst, w * h * sizeof(*c_dst)))
                     fail();
 
-                bench_new(a_dst, w, tmp[0], tmp[1], w, h);
+                bench_new(a_dst, w, tmp[0], tmp[1], w, h HIGHBD_TAIL_SUFFIX);
             }
     report("avg");
 }
@@ -153,10 +169,8 @@
     ALIGN_STK_32(pixel, c_dst, 135 * 135,);
     ALIGN_STK_32(pixel, a_dst, 128 * 128,);
 
-    init_tmp(c, c_dst, tmp);
-
     declare_func(void, pixel *dst, ptrdiff_t dst_stride, const int16_t *tmp1,
-                 const int16_t *tmp2, int w, int h, int weight);
+                 const int16_t *tmp2, int w, int h, int weight HIGHBD_DECL_SUFFIX);
 
     for (int w = 4; w <= 128; w <<= 1)
         if (check_func(c->w_avg, "w_avg_w%d_%dbpc", w, BITDEPTH))
@@ -163,13 +177,19 @@
             for (int h = imax(w / 4, 4); h <= imin(w * 4, 128); h <<= 1)
             {
                 int weight = rand() % 15 + 1;
+#if BITDEPTH == 16
+                const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff;
+#else
+                const int bitdepth_max = 0xff;
+#endif
+                init_tmp(c, c_dst, tmp, bitdepth_max);
 
-                call_ref(c_dst, w, tmp[0], tmp[1], w, h, weight);
-                call_new(a_dst, w, tmp[0], tmp[1], w, h, weight);
+                call_ref(c_dst, w, tmp[0], tmp[1], w, h, weight HIGHBD_TAIL_SUFFIX);
+                call_new(a_dst, w, tmp[0], tmp[1], w, h, weight HIGHBD_TAIL_SUFFIX);
                 if (memcmp(c_dst, a_dst, w * h * sizeof(*c_dst)))
                     fail();
 
-                bench_new(a_dst, w, tmp[0], tmp[1], w, h, weight);
+                bench_new(a_dst, w, tmp[0], tmp[1], w, h, weight HIGHBD_TAIL_SUFFIX);
             }
     report("w_avg");
 }
@@ -180,23 +200,29 @@
     ALIGN_STK_32(pixel,   a_dst, 128 * 128,);
     ALIGN_STK_32(uint8_t, mask,  128 * 128,);
 
-    init_tmp(c, c_dst, tmp);
     for (int i = 0; i < 128 * 128; i++)
         mask[i] = rand() % 65;
 
     declare_func(void, pixel *dst, ptrdiff_t dst_stride, const int16_t *tmp1,
-                 const int16_t *tmp2, int w, int h, const uint8_t *mask);
+                 const int16_t *tmp2, int w, int h, const uint8_t *mask
+                 HIGHBD_DECL_SUFFIX);
 
     for (int w = 4; w <= 128; w <<= 1)
         if (check_func(c->mask, "mask_w%d_%dbpc", w, BITDEPTH))
             for (int h = imax(w / 4, 4); h <= imin(w * 4, 128); h <<= 1)
             {
-                call_ref(c_dst, w, tmp[0], tmp[1], w, h, mask);
-                call_new(a_dst, w, tmp[0], tmp[1], w, h, mask);
+#if BITDEPTH == 16
+                const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff;
+#else
+                const int bitdepth_max = 0xff;
+#endif
+                init_tmp(c, c_dst, tmp, bitdepth_max);
+                call_ref(c_dst, w, tmp[0], tmp[1], w, h, mask HIGHBD_TAIL_SUFFIX);
+                call_new(a_dst, w, tmp[0], tmp[1], w, h, mask HIGHBD_TAIL_SUFFIX);
                 if (memcmp(c_dst, a_dst, w * h * sizeof(*c_dst)))
                     fail();
 
-                bench_new(a_dst, w, tmp[0], tmp[1], w, h, mask);
+                bench_new(a_dst, w, tmp[0], tmp[1], w, h, mask HIGHBD_TAIL_SUFFIX);
             }
     report("mask");
 }
@@ -208,10 +234,9 @@
     ALIGN_STK_32(uint8_t, c_mask, 128 * 128,);
     ALIGN_STK_32(uint8_t, a_mask, 128 * 128,);
 
-    init_tmp(c, c_dst, tmp);
-
     declare_func(void, pixel *dst, ptrdiff_t dst_stride, const int16_t *tmp1,
-                 const int16_t *tmp2, int w, int h, uint8_t *mask, int sign);
+                 const int16_t *tmp2, int w, int h, uint8_t *mask, int sign
+                 HIGHBD_DECL_SUFFIX);
 
     static const uint16_t ss[] = { 444, 422, 420 };
 
@@ -222,9 +247,17 @@
                 for (int h = imax(w / 4, 4); h <= imin(w * 4, 128); h <<= 1)
                 {
                     int sign = rand() & 1;
+#if BITDEPTH == 16
+                    const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff;
+#else
+                    const int bitdepth_max = 0xff;
+#endif
+                    init_tmp(c, c_dst, tmp, bitdepth_max);
 
-                    call_ref(c_dst, w, tmp[0], tmp[1], w, h, c_mask, sign);
-                    call_new(a_dst, w, tmp[0], tmp[1], w, h, a_mask, sign);
+                    call_ref(c_dst, w, tmp[0], tmp[1], w, h, c_mask, sign
+                             HIGHBD_TAIL_SUFFIX);
+                    call_new(a_dst, w, tmp[0], tmp[1], w, h, a_mask, sign
+                             HIGHBD_TAIL_SUFFIX);
                     if (memcmp(c_dst, a_dst, w * h * sizeof(*c_dst)) ||
                         memcmp(c_mask, a_mask, (w * h * sizeof(*c_mask)) >> i))
                     {
@@ -231,7 +264,8 @@
                         fail();
                     }
 
-                    bench_new(a_dst, w, tmp[0], tmp[1], w, h, a_mask, sign);
+                    bench_new(a_dst, w, tmp[0], tmp[1], w, h, a_mask, sign
+                              HIGHBD_TAIL_SUFFIX);
                 }
     report("w_mask");
 }
@@ -242,11 +276,6 @@
     ALIGN_STK_32(pixel, a_dst, 32 * 32,);
     ALIGN_STK_32(uint8_t, mask, 32 * 32,);
 
-    for (int i = 0; i < 32 * 32; i++) {
-        tmp[i] = rand() & ((1 << BITDEPTH) - 1);
-        mask[i] = rand() % 65;
-    }
-
     declare_func(void, pixel *dst, ptrdiff_t dst_stride, const pixel *tmp,
                  int w, int h, const uint8_t *mask);
 
@@ -254,8 +283,17 @@
         const ptrdiff_t dst_stride = w * sizeof(pixel);
         if (check_func(c->blend, "blend_w%d_%dbpc", w, BITDEPTH))
             for (int h = imax(w / 2, 4); h <= imin(w * 2, 32); h <<= 1) {
+#if BITDEPTH == 16
+                const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff;
+#else
+                const int bitdepth_max = 0xff;
+#endif
+                for (int i = 0; i < 32 * 32; i++) {
+                    tmp[i] = rand() & bitdepth_max;
+                    mask[i] = rand() % 65;
+                }
                 for (int i = 0; i < w * h; i++)
-                    c_dst[i] = a_dst[i] = rand() & ((1 << BITDEPTH) - 1);
+                    c_dst[i] = a_dst[i] = rand() & bitdepth_max;
 
                 call_ref(c_dst, dst_stride, tmp, w, h, mask);
                 call_new(a_dst, dst_stride, tmp, w, h, mask);
@@ -273,9 +311,6 @@
     ALIGN_STK_32(pixel, c_dst, 32 * 128,);
     ALIGN_STK_32(pixel, a_dst, 32 * 128,);
 
-    for (int i = 0; i < 32 * 128; i++)
-        tmp[i] = rand() & ((1 << BITDEPTH) - 1);
-
     declare_func(void, pixel *dst, ptrdiff_t dst_stride, const pixel *tmp,
                  int w, int h);
 
@@ -283,8 +318,16 @@
         const ptrdiff_t dst_stride = w * sizeof(pixel);
         if (check_func(c->blend_v, "blend_v_w%d_%dbpc", w, BITDEPTH))
             for (int h = 2; h <= (w == 2 ? 64 : 128); h <<= 1) {
+#if BITDEPTH == 16
+                const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff;
+#else
+                const int bitdepth_max = 0xff;
+#endif
+
                 for (int i = 0; i < w * h; i++)
-                    c_dst[i] = a_dst[i] = rand() & ((1 << BITDEPTH) - 1);
+                    c_dst[i] = a_dst[i] = rand() & bitdepth_max;
+                for (int i = 0; i < 32 * 128; i++)
+                    tmp[i] = rand() & bitdepth_max;
 
                 call_ref(c_dst, dst_stride, tmp, w, h);
                 call_new(a_dst, dst_stride, tmp, w, h);
@@ -302,9 +345,6 @@
     ALIGN_STK_32(pixel, c_dst, 128 * 32,);
     ALIGN_STK_32(pixel, a_dst, 128 * 32,);
 
-    for (int i = 0; i < 128 * 32; i++)
-        tmp[i] = rand() & ((1 << BITDEPTH) - 1);
-
     declare_func(void, pixel *dst, ptrdiff_t dst_stride, const pixel *tmp,
                  int w, int h);
 
@@ -312,8 +352,15 @@
         const ptrdiff_t dst_stride = w * sizeof(pixel);
         if (check_func(c->blend_h, "blend_h_w%d_%dbpc", w, BITDEPTH))
             for (int h = (w == 128 ? 4 : 2); h <= 32; h <<= 1) {
+#if BITDEPTH == 16
+                const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff;
+#else
+                const int bitdepth_max = 0xff;
+#endif
                 for (int i = 0; i < w * h; i++)
-                    c_dst[i] = a_dst[i] = rand() & ((1 << BITDEPTH) - 1);
+                    c_dst[i] = a_dst[i] = rand() & bitdepth_max;
+                for (int i = 0; i < 128 * 32; i++)
+                    tmp[i] = rand() & bitdepth_max;
 
                 call_ref(c_dst, dst_stride, tmp, w, h);
                 call_new(a_dst, dst_stride, tmp, w, h);
@@ -336,24 +383,30 @@
     const ptrdiff_t src_stride = 15 * sizeof(pixel);
 
     declare_func(void, pixel *dst, ptrdiff_t dst_stride, const pixel *src,
-                 ptrdiff_t src_stride, const int16_t *abcd, int mx, int my);
+                 ptrdiff_t src_stride, const int16_t *abcd, int mx, int my
+                 HIGHBD_DECL_SUFFIX);
 
     if (check_func(c->warp8x8, "warp_8x8_%dbpc", BITDEPTH)) {
         const int mx = (rand() & 0x1fff) - 0x800;
         const int my = (rand() & 0x1fff) - 0x800;
+#if BITDEPTH == 16
+        const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff;
+#else
+        const int bitdepth_max = 0xff;
+#endif
 
         for (int i = 0; i < 4; i++)
             abcd[i] = (rand() & 0x1fff) - 0x800;
 
         for (int i = 0; i < 15 * 15; i++)
-            src_buf[i] = rand() & ((1 << BITDEPTH) - 1);
+            src_buf[i] = rand() & bitdepth_max;
 
-        call_ref(c_dst, dst_stride, src, src_stride, abcd, mx, my);
-        call_new(a_dst, dst_stride, src, src_stride, abcd, mx, my);
+        call_ref(c_dst, dst_stride, src, src_stride, abcd, mx, my HIGHBD_TAIL_SUFFIX);
+        call_new(a_dst, dst_stride, src, src_stride, abcd, mx, my HIGHBD_TAIL_SUFFIX);
         if (memcmp(c_dst, a_dst, 8 * 8 * sizeof(*c_dst)))
             fail();
 
-        bench_new(a_dst, dst_stride, src, src_stride, abcd, mx, my);
+        bench_new(a_dst, dst_stride, src, src_stride, abcd, mx, my HIGHBD_TAIL_SUFFIX);
     }
     report("warp8x8");
 }
@@ -367,24 +420,30 @@
     const ptrdiff_t src_stride = 15 * sizeof(pixel);
 
     declare_func(void, int16_t *tmp, ptrdiff_t tmp_stride, const pixel *src,
-                 ptrdiff_t src_stride, const int16_t *abcd, int mx, int my);
+                 ptrdiff_t src_stride, const int16_t *abcd, int mx, int my
+                 HIGHBD_DECL_SUFFIX);
 
     if (check_func(c->warp8x8t, "warp_8x8t_%dbpc", BITDEPTH)) {
         const int mx = (rand() & 0x1fff) - 0x800;
         const int my = (rand() & 0x1fff) - 0x800;
+#if BITDEPTH == 16
+        const int bitdepth_max = rand() & 1 ? 0x3ff : 0xfff;
+#else
+        const int bitdepth_max = 0xff;
+#endif
 
         for (int i = 0; i < 4; i++)
             abcd[i] = (rand() & 0x1fff) - 0x800;
 
         for (int i = 0; i < 15 * 15; i++)
-            src_buf[i] = rand() & ((1 << BITDEPTH) - 1);
+            src_buf[i] = rand() & bitdepth_max;
 
-        call_ref(c_tmp, 8, src, src_stride, abcd, mx, my);
-        call_new(a_tmp, 8, src, src_stride, abcd, mx, my);
+        call_ref(c_tmp, 8, src, src_stride, abcd, mx, my HIGHBD_TAIL_SUFFIX);
+        call_new(a_tmp, 8, src, src_stride, abcd, mx, my HIGHBD_TAIL_SUFFIX);
         if (memcmp(c_tmp, a_tmp, 8 * 8 * sizeof(*c_tmp)))
             fail();
 
-        bench_new(a_tmp, 8, src, src_stride, abcd, mx, my);
+        bench_new(a_tmp, 8, src, src_stride, abcd, mx, my HIGHBD_TAIL_SUFFIX);
     }
     report("warp8x8t");
 }