shithub: dav1d

ref: d085424c9225906e375788ded32f77323ae31f03
parent: 22080aa30cfed267f8c13c293db1dcc34012ecef
author: Henrik Gramner <[email protected]>
date: Thu Feb 20 10:25:40 EST 2020

x86: Add mc avg/w_avg/mask AVX-512 (Ice Lake) asm
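
For reference, the per-pixel arithmetic behind the three ops being ported is spelled out in the rounding comments of the W_AVG and MASK macros below; AVG itself rounds with pw_1024 (via pmulhrsw in the full macro, which is not shown in this diff). A rough scalar sketch of those formulas for 8-bit content follows; it is illustrative only, not dav1d's actual C reference, and tmp1/tmp2 stand for the 16-bit "prep" intermediates:

    #include <stdint.h>

    static inline uint8_t clip_pixel(int v) {
        return v < 0 ? 0 : v > 255 ? 255 : (uint8_t)v;
    }

    /* avg: pw_1024 + pmulhrsw in the asm, i.e. (a + b + 16) >> 5 for 8 bpc */
    static void avg_ref(uint8_t *dst, const int16_t *tmp1, const int16_t *tmp2, int n) {
        for (int i = 0; i < n; i++)
            dst[i] = clip_pixel((tmp1[i] + tmp2[i] + 16) >> 5);
    }

    /* w_avg: (a * weight + b * (16 - weight) + 128) >> 8, per the W_AVG comment */
    static void w_avg_ref(uint8_t *dst, const int16_t *tmp1, const int16_t *tmp2,
                          int n, int weight) {
        for (int i = 0; i < n; i++)
            dst[i] = clip_pixel((tmp1[i] * weight + tmp2[i] * (16 - weight) + 128) >> 8);
    }

    /* mask: (a * m + b * (64 - m) + 512) >> 10, per the MASK comment;
     * one 6-bit mask value (0..64) per pixel */
    static void mask_ref(uint8_t *dst, const int16_t *tmp1, const int16_t *tmp2,
                         int n, const uint8_t *mask) {
        for (int i = 0; i < n; i++)
            dst[i] = clip_pixel((tmp1[i] * mask[i] + tmp2[i] * (64 - mask[i]) + 512) >> 10);
    }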

--- a/src/x86/mc.asm
+++ b/src/x86/mc.asm
@@ -46,6 +46,7 @@
             db 56,  8, 57,  7, 58,  6, 59,  5, 60,  4, 60,  4, 61,  3, 62,  2
             db 64,  0, 64,  0, 64,  0, 64,  0, 64,  0, 64,  0, 64,  0, 64,  0
 
+bidir_sctr_w4:  dd  0,  1,  8,  9,  2,  3, 10, 11,  4,  5, 12, 13,  6,  7, 14, 15
 bilin_h_perm16: db  1,  0,  2,  1,  3,  2,  4,  3,  5,  4,  6,  5,  7,  6,  8,  7
                 db  9,  8, 10,  9, 11, 10, 12, 11, 13, 12, 14, 13, 15, 14, 16, 15
                 db 33, 32, 34, 33, 35, 34, 36, 35, 37, 36, 38, 37, 39, 38, 40, 39
@@ -128,29 +129,9 @@
 pd_32768: dd 32768
 
 cextern mc_subpel_filters
+cextern mc_warp_filter
 %define subpel_filters (mangle(private_prefix %+ _mc_subpel_filters)-8)
 
-%macro BIDIR_JMP_TABLE 1-*
-    %xdefine %1_table (%%table - 2*%2)
-    %xdefine %%base %1_table
-    %xdefine %%prefix mangle(private_prefix %+ _%1)
-    %%table:
-    %rep %0 - 1
-        dd %%prefix %+ .w%2 - %%base
-        %rotate 1
-    %endrep
-%endmacro
-
-BIDIR_JMP_TABLE avg_avx2,        4, 8, 16, 32, 64, 128
-BIDIR_JMP_TABLE w_avg_avx2,      4, 8, 16, 32, 64, 128
-BIDIR_JMP_TABLE mask_avx2,       4, 8, 16, 32, 64, 128
-BIDIR_JMP_TABLE w_mask_420_avx2, 4, 8, 16, 32, 64, 128
-BIDIR_JMP_TABLE w_mask_422_avx2, 4, 8, 16, 32, 64, 128
-BIDIR_JMP_TABLE w_mask_444_avx2, 4, 8, 16, 32, 64, 128
-BIDIR_JMP_TABLE blend_avx2,      4, 8, 16, 32
-BIDIR_JMP_TABLE blend_v_avx2, 2, 4, 8, 16, 32
-BIDIR_JMP_TABLE blend_h_avx2, 2, 4, 8, 16, 32, 32, 32
-
 %macro BASE_JMP_TABLE 3-*
     %xdefine %1_%2_table (%%table - %3)
     %xdefine %%base %1_%2
@@ -161,14 +142,6 @@
     %endrep
 %endmacro
 
-%xdefine put_avx2 mangle(private_prefix %+ _put_bilin_avx2.put)
-%xdefine prep_avx2 mangle(private_prefix %+ _prep_bilin_avx2.prep)
-%xdefine prep_avx512icl mangle(private_prefix %+ _prep_bilin_avx512icl.prep)
-
-BASE_JMP_TABLE put,  avx2,     2, 4, 8, 16, 32, 64, 128
-BASE_JMP_TABLE prep, avx2,        4, 8, 16, 32, 64, 128
-BASE_JMP_TABLE prep, avx512icl,   4, 8, 16, 32, 64, 128
-
 %macro HV_JMP_TABLE 5-*
     %xdefine %%prefix mangle(private_prefix %+ _%1_%2_%3)
     %xdefine %%base %1_%3
@@ -201,17 +174,46 @@
     %endif
 %endmacro
 
-HV_JMP_TABLE put,  8tap,  avx2,      3, 2, 4, 8, 16, 32, 64, 128
-HV_JMP_TABLE prep, 8tap,  avx2,      1,    4, 8, 16, 32, 64, 128
-HV_JMP_TABLE prep, 8tap,  avx512icl, 7,    4, 8, 16, 32, 64, 128
-HV_JMP_TABLE put,  bilin, avx2,      7, 2, 4, 8, 16, 32, 64, 128
-HV_JMP_TABLE prep, bilin, avx2,      7,    4, 8, 16, 32, 64, 128
-HV_JMP_TABLE prep, bilin, avx512icl, 7,    4, 8, 16, 32, 64, 128
+%macro BIDIR_JMP_TABLE 1-*
+    %xdefine %1_table (%%table - 2*%2)
+    %xdefine %%base %1_table
+    %xdefine %%prefix mangle(private_prefix %+ _%1)
+    %%table:
+    %rep %0 - 1
+        dd %%prefix %+ .w%2 - %%base
+        %rotate 1
+    %endrep
+%endmacro
 
+%xdefine put_avx2 mangle(private_prefix %+ _put_bilin_avx2.put)
+%xdefine prep_avx2 mangle(private_prefix %+ _prep_bilin_avx2.prep)
+%xdefine prep_avx512icl mangle(private_prefix %+ _prep_bilin_avx512icl.prep)
+
 %define table_offset(type, fn) type %+ fn %+ SUFFIX %+ _table - type %+ SUFFIX
 
-cextern mc_warp_filter
+BASE_JMP_TABLE put,  avx2,         2, 4, 8, 16, 32, 64, 128
+BASE_JMP_TABLE prep, avx2,            4, 8, 16, 32, 64, 128
+HV_JMP_TABLE put,  bilin, avx2, 7, 2, 4, 8, 16, 32, 64, 128
+HV_JMP_TABLE prep, bilin, avx2, 7,    4, 8, 16, 32, 64, 128
+HV_JMP_TABLE put,  8tap,  avx2, 3, 2, 4, 8, 16, 32, 64, 128
+HV_JMP_TABLE prep, 8tap,  avx2, 1,    4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE avg_avx2,             4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE w_avg_avx2,           4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE mask_avx2,            4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE w_mask_420_avx2,      4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE w_mask_422_avx2,      4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE w_mask_444_avx2,      4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE blend_avx2,           4, 8, 16, 32
+BIDIR_JMP_TABLE blend_v_avx2,      2, 4, 8, 16, 32
+BIDIR_JMP_TABLE blend_h_avx2,      2, 4, 8, 16, 32, 32, 32
 
+BASE_JMP_TABLE prep, avx512icl,            4, 8, 16, 32, 64, 128
+HV_JMP_TABLE prep, bilin, avx512icl, 7,    4, 8, 16, 32, 64, 128
+HV_JMP_TABLE prep, 8tap,  avx512icl, 7,    4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE avg_avx512icl,             4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE w_avg_avx512icl,           4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE mask_avx512icl,            4, 8, 16, 32, 64, 128
+
 SECTION .text
 
 INIT_XMM avx2
@@ -4017,8 +4019,122 @@
     paddd                m0, m15 ; rounded 14-bit result in upper 16 bits of dword
     ret
 
+%macro WRAP_YMM 1+
+    INIT_YMM cpuname
+    %1
+    INIT_ZMM cpuname
+%endmacro
+
 %macro BIDIR_FN 1 ; op
+%if mmsize == 64
+    lea            stride3q, [strideq*3]
+    jmp                  wq
+.w4:
+    cmp                  hd, 8
+    jg .w4_h16
+    WRAP_YMM %1           0
+    vextracti32x4      xmm1, ym0, 1
+    movd   [dstq          ], xm0
+    pextrd [dstq+strideq*1], xm0, 1
+    movd   [dstq+strideq*2], xmm1
+    pextrd [dstq+stride3q ], xmm1, 1
+    jl .w4_ret
+    lea                dstq, [dstq+strideq*4]
+    pextrd [dstq          ], xm0, 2
+    pextrd [dstq+strideq*1], xm0, 3
+    pextrd [dstq+strideq*2], xmm1, 2
+    pextrd [dstq+stride3q ], xmm1, 3
+.w4_ret:
+    RET
+.w4_h16:
+    vpbroadcastd         m7, strided
+    pmulld               m7, [bidir_sctr_w4]
     %1                    0
+    kxnorw               k1, k1, k1
+    vpscatterdd [dstq+m7]{k1}, m0
+    RET
+.w8:
+    cmp                  hd, 4
+    jne .w8_h8
+    WRAP_YMM %1           0
+    vextracti128       xmm1, ym0, 1
+    movq   [dstq          ], xm0
+    movq   [dstq+strideq*1], xmm1
+    movhps [dstq+strideq*2], xm0
+    movhps [dstq+stride3q ], xmm1
+    RET
+.w8_loop:
+    %1_INC_PTR            2
+    lea                dstq, [dstq+strideq*4]
+.w8_h8:
+    %1                    0
+    vextracti32x4      xmm1, ym0, 1
+    vextracti32x4      xmm2, m0, 2
+    vextracti32x4      xmm3, m0, 3
+    movq   [dstq          ], xm0
+    movq   [dstq+strideq*1], xmm1
+    movq   [dstq+strideq*2], xmm2
+    movq   [dstq+stride3q ], xmm3
+    lea                dstq, [dstq+strideq*4]
+    movhps [dstq          ], xm0
+    movhps [dstq+strideq*1], xmm1
+    movhps [dstq+strideq*2], xmm2
+    movhps [dstq+stride3q ], xmm3
+    sub                  hd, 8
+    jg .w8_loop
+    RET
+.w16_loop:
+    %1_INC_PTR            2
+    lea                dstq, [dstq+strideq*4]
+.w16:
+    %1                    0
+    vpermq               m0, m0, q3120
+    mova          [dstq          ], xm0
+    vextracti32x4 [dstq+strideq*1], m0, 2
+    vextracti32x4 [dstq+strideq*2], ym0, 1
+    vextracti32x4 [dstq+stride3q ], m0, 3
+    sub                  hd, 4
+    jg .w16_loop
+    RET
+.w32:
+    pmovzxbq             m7, [warp_8x8_shufA]
+.w32_loop:
+    %1                    0
+    %1_INC_PTR            2
+    vpermq               m0, m7, m0
+    mova          [dstq+strideq*0], ym0
+    vextracti32x8 [dstq+strideq*1], m0, 1
+    lea                dstq, [dstq+strideq*2]
+    sub                  hd, 2
+    jg .w32_loop
+    RET
+.w64:
+    pmovzxbq             m7, [warp_8x8_shufA]
+.w64_loop:
+    %1                    0
+    %1_INC_PTR            2
+    vpermq               m0, m7, m0
+    mova             [dstq], m0
+    add                dstq, strideq
+    dec                  hd
+    jg .w64_loop
+    RET
+.w128:
+    pmovzxbq             m7, [warp_8x8_shufA]
+.w128_loop:
+    %1                    0
+    vpermq               m6, m7, m0
+    %1                    2
+    mova        [dstq+64*0], m6
+    %1_INC_PTR            4
+    vpermq               m6, m7, m0
+    mova        [dstq+64*1], m6
+    add                dstq, strideq
+    dec                  hd
+    jg .w128_loop
+    RET
+%else
+    %1                    0
     lea            stride3q, [strideq*3]
     jmp                  wq
 .w4:
@@ -4042,7 +4158,7 @@
     movd   [dstq          ], xm0
     pextrd [dstq+strideq*1], xm0, 1
     movd   [dstq+strideq*2], xm1
-    pextrd [dstq+stride3q], xm1, 1
+    pextrd [dstq+stride3q ], xm1, 1
     lea                dstq, [dstq+strideq*4]
     pextrd [dstq          ], xm0, 2
     pextrd [dstq+strideq*1], xm0, 3
@@ -4084,7 +4200,7 @@
     lea                dstq, [dstq+strideq*2]
 .w32:
     vpermq               m0, m0, q3120
-    mova             [dstq], m0
+    mova   [dstq+strideq*0], m0
     %1                    2
     vpermq               m0, m0, q3120
     mova   [dstq+strideq*1], m0
@@ -4123,6 +4239,7 @@
     dec                  hd
     jg .w128_loop
     RET
+%endif
 %endmacro
 
 %macro AVG 1 ; src_offset
@@ -4140,14 +4257,17 @@
     add               tmp2q, %1*mmsize
 %endmacro
 
+%macro AVG_FN 0
 cglobal avg, 4, 7, 3, dst, stride, tmp1, tmp2, w, h, stride3
-    lea                  r6, [avg_avx2_table]
+%define base r6-avg %+ SUFFIX %+ _table
+    lea                  r6, [avg %+ SUFFIX %+ _table]
     tzcnt                wd, wm
     movifnidn            hd, hm
     movsxd               wq, dword [r6+wq*4]
-    vpbroadcastd         m2, [pw_1024+r6-avg_avx2_table]
+    vpbroadcastd         m2, [base+pw_1024]
     add                  wq, r6
     BIDIR_FN            AVG
+%endmacro
 
 %macro W_AVG 1 ; src_offset
     ; (a * weight + b * (16 - weight) + 128) >> 8
@@ -4169,13 +4289,15 @@
 
 %define W_AVG_INC_PTR AVG_INC_PTR
 
+%macro W_AVG_FN 0
 cglobal w_avg, 4, 7, 6, dst, stride, tmp1, tmp2, w, h, stride3
-    lea                  r6, [w_avg_avx2_table]
+%define base r6-w_avg %+ SUFFIX %+ _table
+    lea                  r6, [w_avg %+ SUFFIX %+ _table]
     tzcnt                wd, wm
     movifnidn            hd, hm
     vpbroadcastw         m4, r6m ; weight
     movsxd               wq, dword [r6+wq*4]
-    vpbroadcastd         m5, [pw_2048+r6-w_avg_avx2_table]
+    vpbroadcastd         m5, [base+pw_2048]
     psllw                m4, 12 ; (weight-16) << 12 when interpreted as signed
     add                  wq, r6
     cmp           dword r6m, 7
@@ -4187,12 +4309,17 @@
     mov               tmp2q, r6
 .weight_gt7:
     BIDIR_FN          W_AVG
+%endmacro
 
 %macro MASK 1 ; src_offset
     ; (a * m + b * (64 - m) + 512) >> 10
     ; = ((a - b) * m + (b << 6) + 512) >> 10
     ; = ((((b - a) * (-m << 10)) >> 16) + b + 8) >> 4
-    vpermq               m3,     [maskq+(%1+0)*(mmsize/2)], q3120
+%if mmsize == 64
+    vpermq               m3, m8, [maskq+%1*32]
+%else
+    vpermq               m3,     [maskq+%1*16], q3120
+%endif
     mova                 m0,     [tmp2q+(%1+0)*mmsize]
     psubw                m1, m0, [tmp1q+(%1+0)*mmsize]
     psubb                m3, m4, m3
@@ -4214,20 +4341,26 @@
 
 %macro MASK_INC_PTR 1
     add               maskq, %1*mmsize/2
-    add               tmp1q, %1*mmsize
     add               tmp2q, %1*mmsize
+    add               tmp1q, %1*mmsize
 %endmacro
 
+%macro MASK_FN 0
 cglobal mask, 4, 8, 6, dst, stride, tmp1, tmp2, w, h, mask, stride3
-    lea                  r7, [mask_avx2_table]
+%define base r7-mask %+ SUFFIX %+ _table
+    lea                  r7, [mask %+ SUFFIX %+ _table]
     tzcnt                wd, wm
     movifnidn            hd, hm
     mov               maskq, maskmp
     movsxd               wq, dword [r7+wq*4]
     pxor                 m4, m4
-    vpbroadcastd         m5, [pw_2048+r7-mask_avx2_table]
+%if mmsize == 64
+    mova                 m8, [base+bilin_v_perm64]
+%endif
+    vpbroadcastd         m5, [base+pw_2048]
     add                  wq, r7
     BIDIR_FN           MASK
+%endmacro
 
 %macro W_MASK 2-3 0 ; src_offset, mask_out, 4:4:4
     mova                 m0, [tmp1q+(%1+0)*mmsize]
@@ -5324,9 +5457,15 @@
 INIT_YMM avx2
 PREP_BILIN
 PREP_8TAP
+AVG_FN
+W_AVG_FN
+MASK_FN
 
 INIT_ZMM avx512icl
 PREP_BILIN
 PREP_8TAP
+AVG_FN
+W_AVG_FN
+MASK_FN
 
 %endif ; ARCH_X86_64
--- a/src/x86/mc_init_tmpl.c
+++ b/src/x86/mc_init_tmpl.c
@@ -80,10 +80,13 @@
 decl_mct_fn(dav1d_prep_bilin_avx2);
 decl_mct_fn(dav1d_prep_bilin_ssse3);
 
+decl_avg_fn(dav1d_avg_avx512icl);
 decl_avg_fn(dav1d_avg_avx2);
 decl_avg_fn(dav1d_avg_ssse3);
+decl_w_avg_fn(dav1d_w_avg_avx512icl);
 decl_w_avg_fn(dav1d_w_avg_avx2);
 decl_w_avg_fn(dav1d_w_avg_ssse3);
+decl_mask_fn(dav1d_mask_avx512icl);
 decl_mask_fn(dav1d_mask_avx2);
 decl_mask_fn(dav1d_mask_ssse3);
 decl_w_mask_fn(dav1d_w_mask_420_avx2);
@@ -172,10 +175,11 @@
     c->warp8x8t = dav1d_warp_affine_8x8t_sse4;
 #endif
 
+#if ARCH_X86_64
     if (!(flags & DAV1D_X86_CPU_FLAG_AVX2))
         return;
 
-#if BITDEPTH == 8 && ARCH_X86_64
+#if BITDEPTH == 8
     init_mc_fn (FILTER_2D_8TAP_REGULAR,        8tap_regular,        avx2);
     init_mc_fn (FILTER_2D_8TAP_REGULAR_SMOOTH, 8tap_regular_smooth, avx2);
     init_mc_fn (FILTER_2D_8TAP_REGULAR_SHARP,  8tap_regular_sharp,  avx2);
@@ -217,7 +221,7 @@
     if (!(flags & DAV1D_X86_CPU_FLAG_AVX512ICL))
         return;
 
-#if BITDEPTH == 8 && ARCH_X86_64
+#if BITDEPTH == 8
     init_mct_fn(FILTER_2D_8TAP_REGULAR,        8tap_regular,        avx512icl);
     init_mct_fn(FILTER_2D_8TAP_REGULAR_SMOOTH, 8tap_regular_smooth, avx512icl);
     init_mct_fn(FILTER_2D_8TAP_REGULAR_SHARP,  8tap_regular_sharp,  avx512icl);
@@ -228,5 +232,10 @@
     init_mct_fn(FILTER_2D_8TAP_SHARP_SMOOTH,   8tap_sharp_smooth,   avx512icl);
     init_mct_fn(FILTER_2D_8TAP_SHARP,          8tap_sharp,          avx512icl);
     init_mct_fn(FILTER_2D_BILINEAR,            bilin,               avx512icl);
+
+    c->avg = dav1d_avg_avx512icl;
+    c->w_avg = dav1d_w_avg_avx512icl;
+    c->mask = dav1d_mask_avx512icl;
+#endif
 #endif
 }
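
A design note on the new .w4_h16 path: rather than issuing eight narrow stores, it multiplies the new bidir_sctr_w4 table by the stride and writes all sixteen 4-pixel rows with a single vpscatterdd. The table lists the row numbers in the order packuswb leaves them within each 128-bit lane, so 32-bit lane i of the packed result lands at dst + stride * bidir_sctr_w4[i]. A minimal scalar model of that addressing (hypothetical helper, not dav1d code):

    #include <stddef.h>
    #include <stdint.h>
    #include <string.h>

    /* Row order produced by packuswb across the four 128-bit lanes;
     * doubles as the per-lane row index for the scatter. */
    static const uint32_t bidir_sctr_w4[16] = {
        0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15
    };

    /* packed_rows[i] models the 4-pixel row held in 32-bit lane i of the
     * vector register; row bidir_sctr_w4[i] is written to that row's line. */
    static void scatter_w4_h16(uint8_t *dst, ptrdiff_t stride,
                               const uint32_t packed_rows[16]) {
        for (int i = 0; i < 16; i++)
            memcpy(dst + stride * bidir_sctr_w4[i], &packed_rows[i], 4);
    }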