shithub: dav1d

Download patch

ref: 58fc51659634b48026da97eced714d214c97857a
parent: 8b8e9fe85f6875a86ed66726e8964450a318cdc6
author: Henrik Gramner <[email protected]>
date: Fri Nov 9 15:18:18 EST 2018

Split MC blend

The mstride == 0, mstride == 1, and mstride == w cases are very different
from each other, and splitting them into separate functions makes it easier
to optimize them.

Also add some further optimizations to the AVX2 asm that became possible
after this change.

--- a/src/mc.h
+++ b/src/mc.h
@@ -81,11 +81,14 @@ blend_fn drops its mask-stride argument; new blend_dir_fn type for mask-less directional blends
 typedef decl_w_mask_fn(*w_mask_fn);

 #define decl_blend_fn(name) \
-void (name)(pixel *dst, ptrdiff_t dst_stride, \
-            const pixel *tmp, int w, int h, \
-            const uint8_t *mask, ptrdiff_t mstride)
+void (name)(pixel *dst, ptrdiff_t dst_stride, const pixel *tmp, \
+            int w, int h, const uint8_t *mask)
 typedef decl_blend_fn(*blend_fn);

+#define decl_blend_dir_fn(name) \
+void (name)(pixel *dst, ptrdiff_t dst_stride, const pixel *tmp, int w, int h)
+typedef decl_blend_dir_fn(*blend_dir_fn);
+
 #define decl_emu_edge_fn(name) \
 void (name)(intptr_t bw, intptr_t bh, intptr_t iw, intptr_t ih, intptr_t x, intptr_t y, \
             pixel *dst, ptrdiff_t dst_stride, const pixel *src, ptrdiff_t src_stride)
 #define decl_emu_edge_fn(name) \
 void (name)(intptr_t bw, intptr_t bh, intptr_t iw, intptr_t ih, intptr_t x, intptr_t y, \
             pixel *dst, ptrdiff_t dst_stride, const pixel *src, ptrdiff_t src_stride)
@@ -99,6 +102,8 @@ DSP context gains blend_v / blend_h function pointers
     mask_fn mask;
     w_mask_fn w_mask[3 /* 444, 422, 420 */];
     blend_fn blend;
+    blend_dir_fn blend_v; /* OBMC weights indexed by column (x) */
+    blend_dir_fn blend_h; /* OBMC weights indexed by row (y) */
     warp8x8_fn warp8x8;
     warp8x8t_fn warp8x8t;
     emu_edge_fn emu_edge;
--- a/src/mc_tmpl.c
+++ b/src/mc_tmpl.c
@@ -373,21 +373,48 @@ blend_c split into blend_c / blend_v_c / blend_h_c; blend_px parameters parenthesized
     } while (--h);
 }

-static void blend_c(pixel *dst, const ptrdiff_t dst_stride,
-                    const pixel *tmp, const int w, const int h,
-                    const uint8_t *mask, const ptrdiff_t m_stride)
+#define blend_px(a, b, m) (((a) * (64 - (m)) + (b) * (m) + 32) >> 6)
+static NOINLINE void // shared by blend_c and blend_v_c
+blend_internal_c(pixel *dst, const ptrdiff_t dst_stride, const pixel *tmp,
+                 const int w, int h, const uint8_t *mask,
+                 const ptrdiff_t mask_stride)
 {
-    for (int y = 0; y < h; y++) {
+    do {
         for (int x = 0; x < w; x++) {
-#define blend_px(a, b, m) (((a * (64 - m) + b * m) + 32) >> 6)
-            dst[x] = blend_px(dst[x], tmp[x], mask[m_stride == 1 ? 0 : x]);
+            dst[x] = blend_px(dst[x], tmp[x], mask[x]);
         }
         dst += PXSTRIDE(dst_stride);
         tmp += w;
-        mask += m_stride;
-    }
+        mask += mask_stride;
+    } while (--h);
 }

+static void blend_c(pixel *dst, const ptrdiff_t dst_stride, const pixel *tmp,
+                    const int w, const int h, const uint8_t *mask)
+{
+    blend_internal_c(dst, dst_stride, tmp, w, h, mask, w); // mask rows are w bytes apart
+}
+
+static void blend_v_c(pixel *dst, const ptrdiff_t dst_stride, const pixel *tmp,
+                      const int w, const int h)
+{
+    blend_internal_c(dst, dst_stride, tmp, w, h, &dav1d_obmc_masks[w], 0); // stride 0: same column weights on every row
+}
+
+static void blend_h_c(pixel *dst, const ptrdiff_t dst_stride, const pixel *tmp,
+                      const int w, int h)
+{
+    const uint8_t *mask = &dav1d_obmc_masks[h]; // one weight per row
+    do {
+        const int m = *mask++;
+        for (int x = 0; x < w; x++) {
+            dst[x] = blend_px(dst[x], tmp[x], m);
+        }
+        dst += PXSTRIDE(dst_stride);
+        tmp += w;
+    } while (--h);
+}
+
 static void w_mask_c(pixel *dst, const ptrdiff_t dst_stride,
                      const coef *tmp1, const coef *tmp2, const int w, int h,
                      uint8_t *mask, const int sign,
@@ -591,6 +618,8 @@ register the C blend_v / blend_h implementations
     c->w_avg    = w_avg_c;
     c->mask     = mask_c;
     c->blend    = blend_c;
+    c->blend_v  = blend_v_c;
+    c->blend_h  = blend_h_c;
     c->w_mask[0] = w_mask_444_c;
     c->w_mask[1] = w_mask_422_c;
     c->w_mask[2] = w_mask_420_c;
--- a/src/recon_tmpl.c
+++ b/src/recon_tmpl.c
@@ -579,9 +579,8 @@ a_r overlap: per-row OBMC weights (old mask stride 1) now go through blend_h
                          &f->refp[a_r->ref[0] - 1],
                          dav1d_filter_2d[t->a->filter[1][bx4 + x + 1]][t->a->filter[0][bx4 + x + 1]]);
                 if (res) return res;
-                f->dsp->mc.blend(&dst[x * h_mul], dst_stride, lap,
-                                 h_mul * ow4, v_mul * oh4,
-                                 &dav1d_obmc_masks[v_mul * oh4], 1);
+                f->dsp->mc.blend_h(&dst[x * h_mul], dst_stride, lap,
+                                   h_mul * ow4, v_mul * oh4);
                 i++;
             }
             x += imax(a_b_dim[0], 2);
@@ -603,9 +602,8 @@ l_r overlap: per-column OBMC weights (old mask stride 0) now go through blend_v
                          &f->refp[l_r->ref[0] - 1],
                          dav1d_filter_2d[t->l.filter[1][by4 + y + 1]][t->l.filter[0][by4 + y + 1]]);
                 if (res) return res;
-                f->dsp->mc.blend(&dst[y * v_mul * PXSTRIDE(dst_stride)],
-                                 dst_stride, lap, h_mul * ow4, v_mul * oh4,
-                                 &dav1d_obmc_masks[h_mul * ow4], 0);
+                f->dsp->mc.blend_v(&dst[y * v_mul * PXSTRIDE(dst_stride)],
+                                   dst_stride, lap, h_mul * ow4, v_mul * oh4);
                 i++;
             }
             y += imax(l_b_dim[1], 2);
@@ -1144,7 +1142,7 @@ luma interintra/wedge blend: mask stride is now implied by w
                      dav1d_ii_masks[bs][0][b->interintra_mode] :
                      dav1d_wedge_masks[bs][0][0][b->wedge_idx];
             dsp->mc.blend(dst, f->cur.p.stride[0], tmp,
-                          bw4 * 4, bh4 * 4, ii_mask, bw4 * 4);
+                          bw4 * 4, bh4 * 4, ii_mask);
         }

         if (!has_chroma) goto skip_inter_chroma_pred;
@@ -1277,7 +1275,7 @@ chroma interintra blend: mask stride is now implied by w
                     dsp->ipred.intra_pred[m](tmp, cbw4 * 4 * sizeof(pixel),
                                              tl_edge, cbw4 * 4, cbh4 * 4, 0);
                     dsp->mc.blend(uvdst, f->cur.p.stride[1], tmp,
-                                  cbw4 * 4, cbh4 * 4, ii_mask, cbw4 * 4);
+                                  cbw4 * 4, cbh4 * 4, ii_mask);
                 }
             }
         }
--- a/src/x86/mc.asm
+++ b/src/x86/mc.asm
@@ -30,6 +30,23 @@ add the interleaved OBMC weight table used by blend_v / blend_h

 SECTION_RODATA 32

+; dav1d_obmc_masks[] as (64-m, m) byte pairs (dst, tmp weights); size n starts at byte 2*n
+obmc_masks: db  0,  0,  0,  0
+            ; 2
+            db 45, 19, 64,  0
+            ; 4
+            db 39, 25, 50, 14, 59,  5, 64,  0
+            ; 8
+            db 36, 28, 42, 22, 48, 16, 53, 11, 57,  7, 61,  3, 64,  0, 64,  0
+            ; 16
+            db 34, 30, 37, 27, 40, 24, 43, 21, 46, 18, 49, 15, 52, 12, 54, 10
+            db 56,  8, 58,  6, 60,  4, 61,  3, 64,  0, 64,  0, 64,  0, 64,  0
+            ; 32
+            db 33, 31, 35, 29, 36, 28, 38, 26, 40, 24, 41, 23, 43, 21, 44, 20
+            db 45, 19, 47, 17, 48, 16, 50, 14, 51, 13, 52, 12, 53, 11, 55,  9
+            db 56,  8, 57,  7, 58,  6, 59,  5, 60,  4, 60,  4, 61,  3, 62,  2
+            db 64,  0, 64,  0, 64,  0, 64,  0, 64,  0, 64,  0, 64,  0, 64,  0
+
 warp_8x8_shufA: db 0,  2,  4,  6,  1,  3,  5,  7,  1,  3,  5,  7,  2,  4,  6,  8
                 db 4,  6,  8, 10,  5,  7,  9, 11,  5,  7,  9, 11,  6,  8, 10, 12
 warp_8x8_shufB: db 2,  4,  6,  8,  3,  5,  7,  9,  3,  5,  7,  9,  4,  6,  8, 10
@@ -42,10 +59,9 @@ blend_shuf is now a pshufb control that replicates mask pairs
 bilin_h_shuf4:  db 1,  0,  2,  1,  3,  2,  4,  3,  9,  8, 10,  9, 11, 10, 12, 11
 bilin_h_shuf8:  db 1,  0,  2,  1,  3,  2,  4,  3,  5,  4,  6,  5,  7,  6,  8,  7
 deint_shuf4:    db 0,  4,  1,  5,  2,  6,  3,  7,  4,  8,  5,  9,  6, 10,  7, 11
+blend_shuf:     db 0,  1,  0,  1,  0,  1,  0,  1,  2,  3,  2,  3,  2,  3,  2,  3  ; pair 0 (bytes 0-1) x4, pair 1 (bytes 2-3) x4

-blend_shuf: ; bits 0-3: 0, 0, 0, 0, 1, 1, 1, 1
 pb_64:   times 4 db 64
-         times 4 db 1
 pw_8:    times 2 dw 8
 pw_26:   times 2 dw 26
 pw_34:   times 2 dw 34
@@ -61,7 +77,7 @@ BIDIR_JMP_TABLE no longer has default widths; every caller lists them
 cextern mc_subpel_filters
 %define subpel_filters (mangle(private_prefix %+ _mc_subpel_filters)-8)

-%macro BIDIR_JMP_TABLE 1-* 4, 8, 16, 32, 64, 128
+%macro BIDIR_JMP_TABLE 1-*
     %xdefine %1_table (%%table - 2*%2)
     %xdefine %%base %1_table
     %xdefine %%prefix mangle(private_prefix %+ _%1)
@@ -72,11 +88,13 @@ explicit width lists; blend limited to w4..w32; new blend_v / blend_h tables
     %endrep
 %endmacro

-BIDIR_JMP_TABLE avg_avx2
-BIDIR_JMP_TABLE w_avg_avx2
-BIDIR_JMP_TABLE mask_avx2
-BIDIR_JMP_TABLE w_mask_420_avx2
-BIDIR_JMP_TABLE blend_avx2, 2, 4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE avg_avx2,        4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE w_avg_avx2,      4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE mask_avx2,       4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE w_mask_420_avx2, 4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE blend_avx2,      4, 8, 16, 32
+BIDIR_JMP_TABLE blend_v_avx2, 2, 4, 8, 16, 32
+BIDIR_JMP_TABLE blend_h_avx2, 2, 4, 8, 16, 32, 32, 32  ; w64/w128 entries reuse .w32

 %macro BASE_JMP_TABLE 3-*
     %xdefine %1_%2_table (%%table - %3)
@@ -3286,7 +3304,7 @@ blend: mask-stride argument dropped; one more vector register (m6) used
     jg .w128_loop
     RET

-cglobal blend, 3, 7, 6, dst, ds, tmp, w, h, mask, ms
+cglobal blend, 3, 7, 7, dst, ds, tmp, w, h, mask
 %define base r6-blend_avx2_table
     lea                  r6, [blend_avx2_table]
     tzcnt                wd, wm
@@ -3296,55 +3314,125 @@ blend rewritten for w4..w32 with mask stride == w; start of blend_v
     vpbroadcastd         m4, [base+pb_64]
     vpbroadcastd         m5, [base+pw_512]
     add                  wq, r6
-    mov                 msq, msmp
+    lea                  r6, [dsq*3]  ; r6 = 3*ds for the 4-row .w4/.w8 loops (table base no longer needed)
     jmp                  wq
-.w2:
-    cmp                 msq, 1
-    jb .w2_s0
-    je .w2_s1
-.w2_s2:
-    movd                xm1, [maskq]
+.w4:
     movd                xm0, [dstq+dsq*0]
-    pinsrw              xm0, [dstq+dsq*1], 1
-    psubb               xm2, xm4, xm1
-    punpcklbw           xm2, xm1
-    movd                xm1, [tmpq]
-    add               maskq, 2*2
-    add                tmpq, 2*2
-    punpcklbw           xm0, xm1
+    pinsrd              xm0, [dstq+dsq*1], 1
+    vpbroadcastd        xm1, [dstq+dsq*2]
+    pinsrd              xm1, [dstq+r6   ], 3
+    mova                xm6, [maskq]
+    psubb               xm3, xm4, xm6
+    punpcklbw           xm2, xm3, xm6
+    punpckhbw           xm3, xm6
+    mova                xm6, [tmpq]
+    add               maskq, 4*4
+    add                tmpq, 4*4
+    punpcklbw           xm0, xm6
+    punpckhbw           xm1, xm6
     pmaddubsw           xm0, xm2
+    pmaddubsw           xm1, xm3
     pmulhrsw            xm0, xm5
-    packuswb            xm0, xm0
-    pextrw     [dstq+dsq*0], xm0, 0
-    pextrw     [dstq+dsq*1], xm0, 1
-    lea                dstq, [dstq+dsq*2]
-    sub                  hd, 2
-    jg .w2_s2
+    pmulhrsw            xm1, xm5
+    packuswb            xm0, xm1
+    movd       [dstq+dsq*0], xm0
+    pextrd     [dstq+dsq*1], xm0, 1
+    pextrd     [dstq+dsq*2], xm0, 2
+    pextrd     [dstq+r6   ], xm0, 3
+    lea                dstq, [dstq+dsq*4]
+    sub                  hd, 4
+    jg .w4
     RET
-.w2_s1:
-    movd                xm1, [maskq]
-    movd                xm0, [dstq+dsq*0]
-    psubb               xm2, xm4, xm1
-    punpcklbw           xm2, xm1
-    pinsrw              xm0, [dstq+dsq*1], 1
-    movd                xm1, [tmpq]
-    punpcklwd           xm2, xm2
-    add               maskq, 2
-    add                tmpq, 2*2
-    punpcklbw           xm0, xm1
-    pmaddubsw           xm0, xm2
-    pmulhrsw            xm0, xm5
-    packuswb            xm0, xm0
-    pextrw     [dstq+dsq*0], xm0, 0
-    pextrw     [dstq+dsq*1], xm0, 1
+ALIGN function_align
+.w8:
+    movq                xm1, [dstq+dsq*0]
+    movhps              xm1, [dstq+dsq*1]
+    vpbroadcastq         m2, [dstq+dsq*2]
+    vpbroadcastq         m3, [dstq+r6   ]
+    mova                 m0, [maskq]
+    mova                 m6, [tmpq]
+    add               maskq, 8*4
+    add                tmpq, 8*4
+    vpblendd             m1, m2, 0x30
+    vpblendd             m1, m3, 0xc0
+    psubb                m3, m4, m0
+    punpcklbw            m2, m3, m0
+    punpckhbw            m3, m0
+    punpcklbw            m0, m1, m6
+    punpckhbw            m1, m6
+    pmaddubsw            m0, m2
+    pmaddubsw            m1, m3
+    pmulhrsw             m0, m5
+    pmulhrsw             m1, m5
+    packuswb             m0, m1
+    vextracti128        xm1, m0, 1
+    movq       [dstq+dsq*0], xm0
+    movhps     [dstq+dsq*1], xm0
+    movq       [dstq+dsq*2], xm1
+    movhps     [dstq+r6   ], xm1
+    lea                dstq, [dstq+dsq*4]
+    sub                  hd, 4
+    jg .w8
+    RET
+ALIGN function_align
+.w16:
+    mova                 m0, [maskq]
+    mova                xm1, [dstq+dsq*0]
+    vinserti128          m1, [dstq+dsq*1], 1
+    psubb                m3, m4, m0
+    punpcklbw            m2, m3, m0
+    punpckhbw            m3, m0
+    mova                 m6, [tmpq]
+    add               maskq, 16*2
+    add                tmpq, 16*2
+    punpcklbw            m0, m1, m6
+    punpckhbw            m1, m6
+    pmaddubsw            m0, m2
+    pmaddubsw            m1, m3
+    pmulhrsw             m0, m5
+    pmulhrsw             m1, m5
+    packuswb             m0, m1
+    mova         [dstq+dsq*0], xm0
+    vextracti128 [dstq+dsq*1], m0, 1
     lea                dstq, [dstq+dsq*2]
     sub                  hd, 2
-    jg .w2_s1
+    jg .w16
     RET
-.w2_s0:
-    vpbroadcastw        xm0, [maskq]
-    psubb               xm4, xm0
-    punpcklbw           xm4, xm0
+ALIGN function_align
+.w32:
+    mova                 m0, [maskq]
+    mova                 m1, [dstq]
+    mova                 m6, [tmpq]
+    add               maskq, 32
+    add                tmpq, 32
+    psubb                m3, m4, m0
+    punpcklbw            m2, m3, m0
+    punpckhbw            m3, m0
+    punpcklbw            m0, m1, m6
+    punpckhbw            m1, m6
+    pmaddubsw            m0, m2
+    pmaddubsw            m1, m3
+    pmulhrsw             m0, m5
+    pmulhrsw             m1, m5
+    packuswb             m0, m1
+    mova             [dstq], m0
+    add                dstq, dsq
+    dec                  hd
+    jg .w32
+    RET
+
+cglobal blend_v, 3, 6, 6, dst, ds, tmp, w, h, mask
+%define base r5-blend_v_avx2_table
+    lea                  r5, [blend_v_avx2_table]
+    tzcnt                wd, wm
+    movifnidn            hd, hm
+    movsxd               wq, dword [r5+wq*4]
+    vpbroadcastd         m5, [base+pw_512]
+    add                  wq, r5
+    add               maskq, obmc_masks-blend_v_avx2_table  ; maskq (r5) held the table address: now points at obmc_masks
+    jmp                  wq
+.w2:
+    vpbroadcastd        xm2, [maskq+2*2]  ; (64-m, m) pairs for columns 0-1, replicated for both rows
 .w2_s0_loop:
     movd                xm0, [dstq+dsq*0]
     pinsrw              xm0, [dstq+dsq*1], 1
@@ -3351,7 +3439,7 @@ blend_v .w2: weights come from xm2, loaded once before the loop
     movd                xm1, [tmpq]
     add                tmpq, 2*2
     punpcklbw           xm0, xm1
-    pmaddubsw           xm0, xm4
+    pmaddubsw           xm0, xm2
     pmulhrsw            xm0, xm5
     packuswb            xm0, xm0
     pextrw     [dstq+dsq*0], xm0, 0
@@ -3362,17 +3450,11 @@ blend_v .w4: column weights hoisted out of the loop; stride dispatch removed
     RET
 ALIGN function_align
 .w4:
-    cmp                 msq, 1
-    jb .w4_s0
-    je .w4_s1
-.w4_s4:
-    movq                xm1, [maskq]
+    vpbroadcastq        xm2, [maskq+4*2]  ; 4 column pairs, duplicated for both rows
+.w4_loop:
     movd                xm0, [dstq+dsq*0]
     pinsrd              xm0, [dstq+dsq*1], 1
-    psubb               xm2, xm4, xm1
-    punpcklbw           xm2, xm1
     movq                xm1, [tmpq]
-    add               maskq, 4*2
     add                tmpq, 4*2
     punpcklbw           xm0, xm1
     pmaddubsw           xm0, xm2
@@ -3382,116 +3464,19 @@ blend_v .w8: column weights hoisted out of the loop; old per-stride paths removed
     pextrd     [dstq+dsq*1], xm0, 1
     lea                dstq, [dstq+dsq*2]
     sub                  hd, 2
-    jg .w4_s4
+    jg .w4_loop
     RET
-.w4_s1:
-    movq                xm3, [blend_shuf]
-.w4_s1_loop:
-    movd                xm1, [maskq]
-    movd                xm0, [dstq+dsq*0]
-    pshufb              xm1, xm3
-    psubb               xm2, xm4, xm1
-    pinsrd              xm0, [dstq+dsq*1], 1
-    punpcklbw           xm2, xm1
-    movq                xm1, [tmpq]
-    add               maskq, 2
-    add                tmpq, 4*2
-    punpcklbw           xm0, xm1
-    pmaddubsw           xm0, xm2
-    pmulhrsw            xm0, xm5
-    packuswb            xm0, xm0
-    movd       [dstq+dsq*0], xm0
-    pextrd     [dstq+dsq*1], xm0, 1
-    lea                dstq, [dstq+dsq*2]
-    sub                  hd, 2
-    jg .w4_s1_loop
-    RET
-.w4_s0:
-    vpbroadcastd        xm0, [maskq]
-    psubb               xm4, xm0
-    punpcklbw           xm4, xm0
-.w4_s0_loop:
-    movd                xm0, [dstq+dsq*0]
-    pinsrd              xm0, [dstq+dsq*1], 1
-    movq                xm1, [tmpq]
-    add                tmpq, 4*2
-    punpcklbw           xm0, xm1
-    pmaddubsw           xm0, xm4
-    pmulhrsw            xm0, xm5
-    packuswb            xm0, xm0
-    movd       [dstq+dsq*0], xm0
-    pextrd     [dstq+dsq*1], xm0, 1
-    lea                dstq, [dstq+dsq*2]
-    sub                  hd, 2
-    jg .w4_s0_loop
-    RET
 ALIGN function_align
 .w8:
-    cmp                 msq, 1
-    jb .w8_s0
-    je .w8_s1
-.w8_s8:
-    movq                xm1, [maskq+8*1]
-    vinserti128          m1, [maskq+8*0], 1
+    vbroadcasti128       m4, [maskq+8*2]  ; 8 column pairs in both 128-bit lanes
+.w8_loop:
     vpbroadcastq         m2, [dstq+dsq*0]
     movq                xm0, [dstq+dsq*1]
     vpblendd             m0, m2, 0x30
-    psubb                m2, m4, m1
-    punpcklbw            m2, m1
     movq                xm1, [tmpq+8*1]
     vinserti128          m1, [tmpq+8*0], 1
-    add               maskq, 8*2
     add                tmpq, 8*2
     punpcklbw            m0, m1
-    pmaddubsw            m0, m2
-    pmulhrsw             m0, m5
-    vextracti128        xm1, m0, 1
-    packuswb            xm0, xm1
-    movhps     [dstq+dsq*0], xm0
-    movq       [dstq+dsq*1], xm0
-    lea                dstq, [dstq+dsq*2]
-    sub                  hd, 2
-    jg .w8_s8
-    RET
-.w8_s1:
-    vpbroadcastd         m0, [blend_shuf+0]
-    vpbroadcastd        xm3, [blend_shuf+4]
-    vpblendd             m3, m0, 0xf0
-.w8_s1_loop:
-    vpbroadcastd         m0, [maskq]
-    vpbroadcastq         m1, [dstq+dsq*0]
-    pshufb               m0, m3
-    psubb                m2, m4, m0
-    punpcklbw            m2, m0
-    movq                xm0, [dstq+dsq*1]
-    vpblendd             m0, m1, 0x30
-    movq                xm1, [tmpq+8*1]
-    vinserti128          m1, [tmpq+8*0], 1
-    add               maskq, 2
-    add                tmpq, 8*2
-    punpcklbw            m0, m1
-    pmaddubsw            m0, m2
-    pmulhrsw             m0, m5
-    vextracti128        xm1, m0, 1
-    packuswb            xm0, xm1
-    movhps     [dstq+dsq*0], xm0
-    movq       [dstq+dsq*1], xm0
-    lea                dstq, [dstq+dsq*2]
-    sub                  hd, 2
-    jg .w8_s1_loop
-    RET
-.w8_s0:
-    vpbroadcastq         m0, [maskq]
-    psubb                m4, m0
-    punpcklbw            m4, m0
-.w8_s0_loop:
-    vpbroadcastq         m2, [dstq+dsq*0]
-    movq                xm0, [dstq+dsq*1]
-    vpblendd             m0, m2, 0x30
-    movq                xm1, [tmpq+8*1]
-    vinserti128          m1, [tmpq+8*0], 1
-    add                tmpq, 8*2
-    punpcklbw            m0, m1
     pmaddubsw            m0, m4
     pmulhrsw             m0, m5
     vextracti128        xm1, m0, 1
@@ -3500,28 +3485,21 @@ blend_v .w16: weights for columns 0-7 / 8-15 held in m3 / m4 across the loop
     movq       [dstq+dsq*1], xm0
     lea                dstq, [dstq+dsq*2]
     sub                  hd, 2
-    jg .w8_s0_loop
+    jg .w8_loop
     RET
 ALIGN function_align
 .w16:
-    cmp                 msq, 1
-    jb .w16_s0
-    WIN64_SPILL_XMM       7
-    je .w16_s1
-.w16_s16:
-    mova                 m0, [maskq]
+    vbroadcasti128       m3, [maskq+16*2]  ; pairs for columns 0-7
+    vbroadcasti128       m4, [maskq+16*3]  ; pairs for columns 8-15
+.w16_loop:
     mova                xm1, [dstq+dsq*0]
     vinserti128          m1, [dstq+dsq*1], 1
-    psubb                m3, m4, m0
-    punpcklbw            m2, m3, m0
-    punpckhbw            m3, m0
-    mova                 m6, [tmpq]
-    add               maskq, 16*2
+    mova                 m2, [tmpq]
     add                tmpq, 16*2
-    punpcklbw            m0, m1, m6
-    punpckhbw            m1, m6
-    pmaddubsw            m0, m2
-    pmaddubsw            m1, m3
+    punpcklbw            m0, m1, m2
+    punpckhbw            m1, m2
+    pmaddubsw            m0, m3
+    pmaddubsw            m1, m4
     pmulhrsw             m0, m5
     pmulhrsw             m1, m5
     packuswb             m0, m1
@@ -3529,51 +3507,119 @@ blend_v .w32; blend_h entry and .w2-.w16 (mask walked with negative row index)
     vextracti128 [dstq+dsq*1], m0, 1
     lea                dstq, [dstq+dsq*2]
     sub                  hd, 2
-    jg .w16_s16
+    jg .w16_loop
     RET
-.w16_s1:
-    vpbroadcastd        xm6, [blend_shuf]
-    vpbroadcastd         m0, [blend_shuf+4]
-    vpblendd             m6, m0, 0xf0
-.w16_s1_loop:
-    vpbroadcastd         m2, [maskq]
-    mova                xm1, [dstq+dsq*0]
-    pshufb               m2, m6
-    psubb                m3, m4, m2
-    vinserti128          m1, [dstq+dsq*1], 1
-    punpcklbw            m3, m2
+ALIGN function_align
+.w32:
+    mova                xm3, [maskq+16*4]  ; lane 0: pairs for columns 0-7
+    vinserti128          m3, [maskq+16*6], 1  ; lane 1: pairs for columns 16-23
+    mova                xm4, [maskq+16*5]  ; lane 0: pairs for columns 8-15
+    vinserti128          m4, [maskq+16*7], 1  ; lane 1: pairs for columns 24-31
+.w32_loop:
+    mova                 m1, [dstq]
     mova                 m2, [tmpq]
-    add               maskq, 2
-    add                tmpq, 16*2
+    add                tmpq, 32
     punpcklbw            m0, m1, m2
     punpckhbw            m1, m2
     pmaddubsw            m0, m3
-    pmaddubsw            m1, m3
+    pmaddubsw            m1, m4
     pmulhrsw             m0, m5
     pmulhrsw             m1, m5
     packuswb             m0, m1
-    mova         [dstq+dsq*0], xm0
-    vextracti128 [dstq+dsq*1], m0, 1
+    mova             [dstq], m0
+    add                dstq, dsq
+    dec                  hd
+    jg .w32_loop
+    RET
+
+cglobal blend_h, 4, 7, 6, dst, ds, tmp, w, h, mask
+%define base r5-blend_h_avx2_table
+    lea                  r5, [blend_h_avx2_table]
+    mov                 r6d, wd  ; keep w for the .w32 path
+    tzcnt                wd, wd
+    mov                  hd, hm
+    movsxd               wq, dword [r5+wq*4]
+    vpbroadcastd         m5, [base+pw_512]
+    add                  wq, r5
+    lea               maskq, [base+obmc_masks+hq*4]  ; end of the size-h pairs (which start at byte 2*h)
+    neg                  hq  ; hq counts from -h up to 0; row pair is at [maskq+hq*2]
+    jmp                  wq
+.w2:
+    movd                xm0, [dstq+dsq*0]
+    pinsrw              xm0, [dstq+dsq*1], 1
+    movd                xm2, [maskq+hq*2]
+    movd                xm1, [tmpq]
+    add                tmpq, 2*2
+    punpcklwd           xm2, xm2
+    punpcklbw           xm0, xm1
+    pmaddubsw           xm0, xm2
+    pmulhrsw            xm0, xm5
+    packuswb            xm0, xm0
+    pextrw     [dstq+dsq*0], xm0, 0
+    pextrw     [dstq+dsq*1], xm0, 1
     lea                dstq, [dstq+dsq*2]
-    sub                  hd, 2
-    jg .w16_s1_loop
+    add                  hq, 2
+    jl .w2
     RET
-.w16_s0:
-    %assign stack_offset stack_offset - stack_size_padded
-    WIN64_SPILL_XMM       6
-    vbroadcasti128       m0, [maskq]
-    psubb                m4, m0
-    punpcklbw            m3, m4, m0
-    punpckhbw            m4, m0
-.w16_s0_loop:
+ALIGN function_align
+.w4:
+    mova                xm3, [blend_shuf]
+.w4_loop:
+    movd                xm0, [dstq+dsq*0]
+    pinsrd              xm0, [dstq+dsq*1], 1
+    movd                xm2, [maskq+hq*2]  ; pairs for this row and the next; pshufb only reads bytes 0-3
+    movq                xm1, [tmpq]
+    add                tmpq, 4*2
+    pshufb              xm2, xm3
+    punpcklbw           xm0, xm1
+    pmaddubsw           xm0, xm2
+    pmulhrsw            xm0, xm5
+    packuswb            xm0, xm0
+    movd       [dstq+dsq*0], xm0
+    pextrd     [dstq+dsq*1], xm0, 1
+    lea                dstq, [dstq+dsq*2]
+    add                  hq, 2
+    jl .w4_loop
+    RET
+ALIGN function_align
+.w8:
+    vbroadcasti128       m4, [blend_shuf]
+    shufpd               m4, m4, 0x03  ; lane 0 picks pair 1 (row 1), lane 1 picks pair 0 (row 0)
+.w8_loop:
+    vpbroadcastq         m1, [dstq+dsq*0]
+    movq                xm0, [dstq+dsq*1]
+    vpblendd             m0, m1, 0x30
+    vpbroadcastd         m3, [maskq+hq*2]
+    movq                xm1, [tmpq+8*1]
+    vinserti128          m1, [tmpq+8*0], 1
+    add                tmpq, 8*2
+    pshufb               m3, m4
+    punpcklbw            m0, m1
+    pmaddubsw            m0, m3
+    pmulhrsw             m0, m5
+    vextracti128        xm1, m0, 1
+    packuswb            xm0, xm1
+    movhps     [dstq+dsq*0], xm0
+    movq       [dstq+dsq*1], xm0
+    lea                dstq, [dstq+dsq*2]
+    add                  hq, 2
+    jl .w8_loop
+    RET
+ALIGN function_align
+.w16:
+    vbroadcasti128       m4, [blend_shuf]
+    shufpd               m4, m4, 0x0c  ; lane 0 picks pair 0 (row 0), lane 1 picks pair 1 (row 1)
+.w16_loop:
     mova                xm1, [dstq+dsq*0]
     vinserti128          m1, [dstq+dsq*1], 1
+    vpbroadcastd         m3, [maskq+hq*2]
     mova                 m2, [tmpq]
     add                tmpq, 16*2
+    pshufb               m3, m4
     punpcklbw            m0, m1, m2
     punpckhbw            m1, m2
     pmaddubsw            m0, m3
-    pmaddubsw            m1, m4
+    pmaddubsw            m1, m3
     pmulhrsw             m0, m5
     pmulhrsw             m1, m5
     packuswb             m0, m1
@@ -3580,60 +3626,17 @@ blend_h .w32 (also w64/w128): one broadcast weight pair per row, inner loop over 32-px chunks
     mova         [dstq+dsq*0], xm0
     vextracti128 [dstq+dsq*1], m0, 1
     lea                dstq, [dstq+dsq*2]
-    sub                  hd, 2
-    jg .w16_s0_loop
+    add                  hq, 2
+    jl .w16_loop
     RET
 ALIGN function_align
-.w32:
-    mov                  wd, 32
-    jmp .w32_start
-.w64:
-    mov                  wd, 64
-    jmp .w32_start
-.w128:
-    mov                  wd, 128
-.w32_start:
-    WIN64_SPILL_XMM       7
-    cmp                 msq, 1
-    jb .w32_s0
-    je .w32_s1
-    sub                 dsq, wq
-.w32_s32:
-    mov                 r6d, wd
-.w32_s32_loop:
-    mova                 m0, [maskq]
+.w32: ; w32/w64/w128
+    sub                 dsq, r6  ; ds -= w: the inner loop already advances dstq by w
+.w32_loop0:
+    vpbroadcastw         m3, [maskq+hq*2]  ; this row's (64-m, m) pair in every word
+    mov                  wd, r6d
+.w32_loop:
     mova                 m1, [dstq]
-    psubb                m3, m4, m0
-    punpcklbw            m2, m3, m0
-    punpckhbw            m3, m0
-    mova                 m6, [tmpq]
-    add               maskq, 32
-    add                tmpq, 32
-    punpcklbw            m0, m1, m6
-    punpckhbw            m1, m6
-    pmaddubsw            m0, m2
-    pmaddubsw            m1, m3
-    pmulhrsw             m0, m5
-    pmulhrsw             m1, m5
-    packuswb             m0, m1
-    mova             [dstq], m0
-    add                dstq, 32
-    sub                 r6d, 32
-    jg .w32_s32_loop
-    add                dstq, dsq
-    dec                  hd
-    jg .w32_s32
-    RET
-.w32_s1:
-    sub                 dsq, wq
-.w32_s1_loop0:
-    vpbroadcastb         m0, [maskq]
-    mov                 r6d, wd
-    inc               maskq
-    psubb                m3, m4, m0
-    punpcklbw            m3, m0
-.w32_s1_loop:
-    mova                 m1, [dstq]
     mova                 m2, [tmpq]
     add                tmpq, 32
     punpcklbw            m0, m1, m2
@@ -3645,49 +3648,11 @@ blend_h .w32 loop tail; the old mask-stride-0 column-major path is removed
     packuswb             m0, m1
     mova             [dstq], m0
     add                dstq, 32
-    sub                 r6d, 32
-    jg .w32_s1_loop
+    sub                  wd, 32
+    jg .w32_loop
     add                dstq, dsq
-    dec                  hd
-    jg .w32_s1_loop0
-    RET
-.w32_s0:
-%if WIN64
-    PUSH                 r7
-    PUSH                 r8
-    %define regs_used 9
-%endif
-    lea                 r6d, [hq+wq*8-256]
-    mov                  r7, dstq
-    mov                  r8, tmpq
-.w32_s0_loop0:
-    mova                 m0, [maskq]
-    add               maskq, 32
-    psubb                m3, m4, m0
-    punpcklbw            m2, m3, m0
-    punpckhbw            m3, m0
-.w32_s0_loop:
-    mova                 m1, [dstq]
-    mova                 m6, [tmpq]
-    add                tmpq, wq
-    punpcklbw            m0, m1, m6
-    punpckhbw            m1, m6
-    pmaddubsw            m0, m2
-    pmaddubsw            m1, m3
-    pmulhrsw             m0, m5
-    pmulhrsw             m1, m5
-    packuswb             m0, m1
-    mova             [dstq], m0
-    add                dstq, dsq
-    dec                  hd
-    jg .w32_s0_loop
-    add                  r7, 32
-    add                  r8, 32
-    mov                dstq, r7
-    mov                tmpq, r8
-    mov                  hb, r6b
-    sub                 r6d, 256
-    jg .w32_s0_loop0
+    inc                  hq  ; next row; hq counts up to 0
+    jl .w32_loop0
     RET

 cglobal emu_edge, 10, 13, 1, bw, bh, iw, ih, x, y, dst, dstride, src, sstride, \
--- a/src/x86/mc_init_tmpl.c
+++ b/src/x86/mc_init_tmpl.c
@@ -55,6 +55,8 @@ declare the AVX2 blend_v / blend_h entry points
 decl_mask_fn(dav1d_mask_avx2);
 decl_w_mask_fn(dav1d_w_mask_420_avx2);
 decl_blend_fn(dav1d_blend_avx2);
+decl_blend_dir_fn(dav1d_blend_v_avx2);
+decl_blend_dir_fn(dav1d_blend_h_avx2);

 decl_warp8x8_fn(dav1d_warp_affine_8x8_avx2);
 decl_warp8x8t_fn(dav1d_warp_affine_8x8t_avx2);
@@ -98,6 +100,8 @@ install the AVX2 blend_v / blend_h into the DSP context
     c->mask = dav1d_mask_avx2;
     c->w_mask[2] = dav1d_w_mask_420_avx2;
     c->blend = dav1d_blend_avx2;
+    c->blend_v = dav1d_blend_v_avx2;
+    c->blend_h = dav1d_blend_h_avx2;

     c->warp8x8  = dav1d_warp_affine_8x8_avx2;
     c->warp8x8t = dav1d_warp_affine_8x8t_avx2;
--- a/tests/checkasm/mc.c
+++ b/tests/checkasm/mc.c
@@ -237,40 +237,95 @@ checkasm: blend tested for w4..w32 only; new check_blend_v / check_blend_h
 }

 static void check_blend(Dav1dMCDSPContext *const c) {
-    ALIGN_STK_32(pixel, tmp, 128 * 32,);
-    ALIGN_STK_32(pixel, c_dst, 128 * 32,);
-    ALIGN_STK_32(pixel, a_dst, 128 * 32,);
-    ALIGN_STK_32(uint8_t, mask, 128 * 32,);
+    ALIGN_STK_32(pixel, tmp, 32 * 32,);
+    ALIGN_STK_32(pixel, c_dst, 32 * 32,);
+    ALIGN_STK_32(pixel, a_dst, 32 * 32,);
+    ALIGN_STK_32(uint8_t, mask, 32 * 32,);

-    for (int i = 0; i < 128 * 32; i++) {
+    for (int i = 0; i < 32 * 32; i++) {
         tmp[i] = rand() & ((1 << BITDEPTH) - 1);
         mask[i] = rand() % 65;
     }

     declare_func(void, pixel *dst, ptrdiff_t dst_stride, const pixel *tmp,
-                 int w, int h, const uint8_t *mask, ptrdiff_t mstride);
+                 int w, int h, const uint8_t *mask);

-    for (int w = 2; w <= 128; w <<= 1) {
+    for (int w = 4; w <= 32; w <<= 1) {
         const ptrdiff_t dst_stride = w * sizeof(pixel);
-        const int h_min = (w == 128) ? 4 : 2;
-        const int h_max = (w > 32) ? 32 : (w == 2) ? 64 : 128;
-        for (int ms = 0; ms <= w; ms += ms ? w - 1 : 1)
-            if (check_func(c->blend, "blend_w%d_ms%d_%dbpc", w, ms, BITDEPTH))
-                for (int h = h_min; h <= h_max; h <<= 1) {
-                    for (int i = 0; i < w * h; i++)
-                        c_dst[i] = a_dst[i] = rand() & ((1 << BITDEPTH) - 1);
+        if (check_func(c->blend, "blend_w%d_%dbpc", w, BITDEPTH))
+            for (int h = imax(w / 2, 4); h <= imin(w * 2, 32); h <<= 1) {
+                for (int i = 0; i < w * h; i++)
+                    c_dst[i] = a_dst[i] = rand() & ((1 << BITDEPTH) - 1);

-                    call_ref(c_dst, dst_stride, tmp, w, h, mask, ms);
-                    call_new(a_dst, dst_stride, tmp, w, h, mask, ms);
-                    if (memcmp(c_dst, a_dst, w * h * sizeof(*c_dst)))
-                        fail();
+                call_ref(c_dst, dst_stride, tmp, w, h, mask);
+                call_new(a_dst, dst_stride, tmp, w, h, mask);
+                if (memcmp(c_dst, a_dst, w * h * sizeof(*c_dst)))
+                    fail();

-                    bench_new(a_dst, dst_stride, tmp, w, h, mask, ms);
-                }
+                bench_new(a_dst, dst_stride, tmp, w, h, mask);
+            }
     }
     report("blend");
 }

+static void check_blend_v(Dav1dMCDSPContext *const c) {
+    ALIGN_STK_32(pixel, tmp,   32 * 128,); // w <= 32, h <= 128
+    ALIGN_STK_32(pixel, c_dst, 32 * 128,);
+    ALIGN_STK_32(pixel, a_dst, 32 * 128,);
+
+    for (int i = 0; i < 32 * 128; i++)
+        tmp[i] = rand() & ((1 << BITDEPTH) - 1);
+
+    declare_func(void, pixel *dst, ptrdiff_t dst_stride, const pixel *tmp,
+                 int w, int h);
+
+    for (int w = 2; w <= 32; w <<= 1) {
+        const ptrdiff_t dst_stride = w * sizeof(pixel);
+        if (check_func(c->blend_v, "blend_v_w%d_%dbpc", w, BITDEPTH))
+            for (int h = 2; h <= (w == 2 ? 64 : 128); h <<= 1) {
+                for (int i = 0; i < w * h; i++)
+                    c_dst[i] = a_dst[i] = rand() & ((1 << BITDEPTH) - 1);
+
+                call_ref(c_dst, dst_stride, tmp, w, h);
+                call_new(a_dst, dst_stride, tmp, w, h);
+                if (memcmp(c_dst, a_dst, w * h * sizeof(*c_dst)))
+                    fail();
+
+                bench_new(a_dst, dst_stride, tmp, w, h);
+            }
+    }
+    report("blend_v");
+}
+
+static void check_blend_h(Dav1dMCDSPContext *const c) {
+    ALIGN_STK_32(pixel, tmp,   128 * 32,); // w <= 128, h <= 32
+    ALIGN_STK_32(pixel, c_dst, 128 * 32,);
+    ALIGN_STK_32(pixel, a_dst, 128 * 32,);
+
+    for (int i = 0; i < 128 * 32; i++)
+        tmp[i] = rand() & ((1 << BITDEPTH) - 1);
+
+    declare_func(void, pixel *dst, ptrdiff_t dst_stride, const pixel *tmp,
+                 int w, int h);
+
+    for (int w = 2; w <= 128; w <<= 1) {
+        const ptrdiff_t dst_stride = w * sizeof(pixel);
+        if (check_func(c->blend_h, "blend_h_w%d_%dbpc", w, BITDEPTH))
+            for (int h = (w == 128 ? 4 : 2); h <= 32; h <<= 1) {
+                for (int i = 0; i < w * h; i++)
+                    c_dst[i] = a_dst[i] = rand() & ((1 << BITDEPTH) - 1);
+
+                call_ref(c_dst, dst_stride, tmp, w, h);
+                call_new(a_dst, dst_stride, tmp, w, h);
+                if (memcmp(c_dst, a_dst, w * h * sizeof(*c_dst)))
+                    fail();
+
+                bench_new(a_dst, dst_stride, tmp, w, h);
+            }
+    }
+    report("blend_h");
+}
+
 static void check_warp8x8(Dav1dMCDSPContext *const c) {
     ALIGN_STK_32(pixel, src_buf, 15 * 15,);
     ALIGN_STK_32(pixel, c_dst,    8 *  8,);
@@ -430,6 +485,8 @@ run the new blend_v / blend_h checks after check_blend
     check_mask(&c);
     check_w_mask(&c);
     check_blend(&c);
+    check_blend_v(&c);
+    check_blend_h(&c);
     check_warp8x8(&c);
     check_warp8x8t(&c);
     check_emuedge(&c);