ref: d085424c9225906e375788ded32f77323ae31f03
parent: 22080aa30cfed267f8c13c293db1dcc34012ecef
author: Henrik Gramner <[email protected]>
date: Thu Feb 20 10:25:40 EST 2020
x86: Add mc avg/w_avg/mask AVX-512 (Ice Lake) asm
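
For reference, a scalar sketch of what the three kernels compute for 8bpc
(the w_avg and mask rounding matches the per-macro comments in the asm;
avg's rounding follows from the pw_1024 pmulhrsw; the helper names below
are illustrative only):

    #include <stdint.h>

    static inline int clip8(int v) { return v < 0 ? 0 : v > 255 ? 255 : v; }

    /* avg: plain bidirectional average of the 16-bit intermediate buffers */
    void avg_ref(uint8_t *dst, const int16_t *t1, const int16_t *t2, int n) {
        for (int i = 0; i < n; i++)
            dst[i] = clip8((t1[i] + t2[i] + 16) >> 5);
    }

    /* w_avg: (a * weight + b * (16 - weight) + 128) >> 8 */
    void w_avg_ref(uint8_t *dst, const int16_t *t1, const int16_t *t2,
                   int n, int weight) {
        for (int i = 0; i < n; i++)
            dst[i] = clip8((t1[i] * weight + t2[i] * (16 - weight) + 128) >> 8);
    }

    /* mask: (a * m + b * (64 - m) + 512) >> 10, with a per-pixel mask */
    void mask_ref(uint8_t *dst, const int16_t *t1, const int16_t *t2,
                  const uint8_t *m, int n) {
        for (int i = 0; i < n; i++)
            dst[i] = clip8((t1[i] * m[i] + t2[i] * (64 - m[i]) + 512) >> 10);
    }
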
--- a/src/x86/mc.asm
+++ b/src/x86/mc.asm
@@ -46,6 +46,7 @@
db 56, 8, 57, 7, 58, 6, 59, 5, 60, 4, 60, 4, 61, 3, 62, 2
db 64, 0, 64, 0, 64, 0, 64, 0, 64, 0, 64, 0, 64, 0, 64, 0
+bidir_sctr_w4: dd 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15
bilin_h_perm16: db 1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6, 8, 7
db 9, 8, 10, 9, 11, 10, 12, 11, 13, 12, 14, 13, 15, 14, 16, 15
db 33, 32, 34, 33, 35, 34, 36, 35, 37, 36, 38, 37, 39, 38, 40, 39
@@ -128,29 +129,9 @@
pd_32768: dd 32768
cextern mc_subpel_filters
+cextern mc_warp_filter
%define subpel_filters (mangle(private_prefix %+ _mc_subpel_filters)-8)
-%macro BIDIR_JMP_TABLE 1-*
- %xdefine %1_table (%%table - 2*%2)
- %xdefine %%base %1_table
- %xdefine %%prefix mangle(private_prefix %+ _%1)
- %%table:
- %rep %0 - 1
- dd %%prefix %+ .w%2 - %%base
- %rotate 1
- %endrep
-%endmacro
-
-BIDIR_JMP_TABLE avg_avx2, 4, 8, 16, 32, 64, 128
-BIDIR_JMP_TABLE w_avg_avx2, 4, 8, 16, 32, 64, 128
-BIDIR_JMP_TABLE mask_avx2, 4, 8, 16, 32, 64, 128
-BIDIR_JMP_TABLE w_mask_420_avx2, 4, 8, 16, 32, 64, 128
-BIDIR_JMP_TABLE w_mask_422_avx2, 4, 8, 16, 32, 64, 128
-BIDIR_JMP_TABLE w_mask_444_avx2, 4, 8, 16, 32, 64, 128
-BIDIR_JMP_TABLE blend_avx2, 4, 8, 16, 32
-BIDIR_JMP_TABLE blend_v_avx2, 2, 4, 8, 16, 32
-BIDIR_JMP_TABLE blend_h_avx2, 2, 4, 8, 16, 32, 32, 32
-
%macro BASE_JMP_TABLE 3-*
%xdefine %1_%2_table (%%table - %3)
%xdefine %%base %1_%2
@@ -161,14 +142,6 @@
%endrep
%endmacro
-%xdefine put_avx2 mangle(private_prefix %+ _put_bilin_avx2.put)
-%xdefine prep_avx2 mangle(private_prefix %+ _prep_bilin_avx2.prep)
-%xdefine prep_avx512icl mangle(private_prefix %+ _prep_bilin_avx512icl.prep)
-
-BASE_JMP_TABLE put, avx2, 2, 4, 8, 16, 32, 64, 128
-BASE_JMP_TABLE prep, avx2, 4, 8, 16, 32, 64, 128
-BASE_JMP_TABLE prep, avx512icl, 4, 8, 16, 32, 64, 128
-
%macro HV_JMP_TABLE 5-*
%xdefine %%prefix mangle(private_prefix %+ _%1_%2_%3)
%xdefine %%base %1_%3
@@ -201,17 +174,46 @@
%endif
%endmacro
-HV_JMP_TABLE put, 8tap, avx2, 3, 2, 4, 8, 16, 32, 64, 128
-HV_JMP_TABLE prep, 8tap, avx2, 1, 4, 8, 16, 32, 64, 128
-HV_JMP_TABLE prep, 8tap, avx512icl, 7, 4, 8, 16, 32, 64, 128
-HV_JMP_TABLE put, bilin, avx2, 7, 2, 4, 8, 16, 32, 64, 128
-HV_JMP_TABLE prep, bilin, avx2, 7, 4, 8, 16, 32, 64, 128
-HV_JMP_TABLE prep, bilin, avx512icl, 7, 4, 8, 16, 32, 64, 128
+%macro BIDIR_JMP_TABLE 1-*
+ %xdefine %1_table (%%table - 2*%2)
+ %xdefine %%base %1_table
+ %xdefine %%prefix mangle(private_prefix %+ _%1)
+ %%table:
+ %rep %0 - 1
+ dd %%prefix %+ .w%2 - %%base
+ %rotate 1
+ %endrep
+%endmacro
+%xdefine put_avx2 mangle(private_prefix %+ _put_bilin_avx2.put)
+%xdefine prep_avx2 mangle(private_prefix %+ _prep_bilin_avx2.prep)
+%xdefine prep_avx512icl mangle(private_prefix %+ _prep_bilin_avx512icl.prep)
+
%define table_offset(type, fn) type %+ fn %+ SUFFIX %+ _table - type %+ SUFFIX
-cextern mc_warp_filter
+BASE_JMP_TABLE put, avx2, 2, 4, 8, 16, 32, 64, 128
+BASE_JMP_TABLE prep, avx2, 4, 8, 16, 32, 64, 128
+HV_JMP_TABLE put, bilin, avx2, 7, 2, 4, 8, 16, 32, 64, 128
+HV_JMP_TABLE prep, bilin, avx2, 7, 4, 8, 16, 32, 64, 128
+HV_JMP_TABLE put, 8tap, avx2, 3, 2, 4, 8, 16, 32, 64, 128
+HV_JMP_TABLE prep, 8tap, avx2, 1, 4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE avg_avx2, 4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE w_avg_avx2, 4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE mask_avx2, 4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE w_mask_420_avx2, 4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE w_mask_422_avx2, 4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE w_mask_444_avx2, 4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE blend_avx2, 4, 8, 16, 32
+BIDIR_JMP_TABLE blend_v_avx2, 2, 4, 8, 16, 32
+BIDIR_JMP_TABLE blend_h_avx2, 2, 4, 8, 16, 32, 32, 32
+BASE_JMP_TABLE prep, avx512icl, 4, 8, 16, 32, 64, 128
+HV_JMP_TABLE prep, bilin, avx512icl, 7, 4, 8, 16, 32, 64, 128
+HV_JMP_TABLE prep, 8tap, avx512icl, 7, 4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE avg_avx512icl, 4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE w_avg_avx512icl, 4, 8, 16, 32, 64, 128
+BIDIR_JMP_TABLE mask_avx512icl, 4, 8, 16, 32, 64, 128
+
SECTION .text
INIT_XMM avx2
@@ -4017,8 +4019,122 @@
paddd m0, m15 ; rounded 14-bit result in upper 16 bits of dword
ret
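+; expand the supplied code with 256-bit registers (for blocks too small to fill a zmm), then switch back to 512-bit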
+%macro WRAP_YMM 1+
+ INIT_YMM cpuname
+ %1
+ INIT_ZMM cpuname
+%endmacro
+
%macro BIDIR_FN 1 ; op
+%if mmsize == 64
+ lea stride3q, [strideq*3]
+ jmp wq
+.w4:
+ cmp hd, 8
+ jg .w4_h16
+ WRAP_YMM %1 0
+ vextracti32x4 xmm1, ym0, 1
+ movd [dstq ], xm0
+ pextrd [dstq+strideq*1], xm0, 1
+ movd [dstq+strideq*2], xmm1
+ pextrd [dstq+stride3q ], xmm1, 1
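+ ; the flags from the h comparison above are still valid here; h < 8 means only these 4 rows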
+ jl .w4_ret
+ lea dstq, [dstq+strideq*4]
+ pextrd [dstq ], xm0, 2
+ pextrd [dstq+strideq*1], xm0, 3
+ pextrd [dstq+strideq*2], xmm1, 2
+ pextrd [dstq+stride3q ], xmm1, 3
+.w4_ret:
+ RET
+.w4_h16:
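+ ; each of the 16 output rows is 4 bytes wide; compute per-row offsets (row*stride) and store everything with a single dword scatter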
+ vpbroadcastd m7, strided
+ pmulld m7, [bidir_sctr_w4]
%1 0
+ kxnorw k1, k1, k1
+ vpscatterdd [dstq+m7]{k1}, m0
+ RET
+.w8:
+ cmp hd, 4
+ jne .w8_h8
+ WRAP_YMM %1 0
+ vextracti128 xmm1, ym0, 1
+ movq [dstq ], xm0
+ movq [dstq+strideq*1], xmm1
+ movhps [dstq+strideq*2], xm0
+ movhps [dstq+stride3q ], xmm1
+ RET
+.w8_loop:
+ %1_INC_PTR 2
+ lea dstq, [dstq+strideq*4]
+.w8_h8:
+ %1 0
+ vextracti32x4 xmm1, ym0, 1
+ vextracti32x4 xmm2, m0, 2
+ vextracti32x4 xmm3, m0, 3
+ movq [dstq ], xm0
+ movq [dstq+strideq*1], xmm1
+ movq [dstq+strideq*2], xmm2
+ movq [dstq+stride3q ], xmm3
+ lea dstq, [dstq+strideq*4]
+ movhps [dstq ], xm0
+ movhps [dstq+strideq*1], xmm1
+ movhps [dstq+strideq*2], xmm2
+ movhps [dstq+stride3q ], xmm3
+ sub hd, 8
+ jg .w8_loop
+ RET
+.w16_loop:
+ %1_INC_PTR 2
+ lea dstq, [dstq+strideq*4]
+.w16:
+ %1 0
+ vpermq m0, m0, q3120
+ mova [dstq ], xm0
+ vextracti32x4 [dstq+strideq*1], m0, 2
+ vextracti32x4 [dstq+strideq*2], ym0, 1
+ vextracti32x4 [dstq+stride3q ], m0, 3
+ sub hd, 4
+ jg .w16_loop
+ RET
+.w32:
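+ ; build the qword permutation {0, 2, 4, 6, 1, 3, 5, 7} from the first 8 bytes of warp_8x8_shufA; it deinterleaves the per-lane output of packuswb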
+ pmovzxbq m7, [warp_8x8_shufA]
+.w32_loop:
+ %1 0
+ %1_INC_PTR 2
+ vpermq m0, m7, m0
+ mova [dstq+strideq*0], ym0
+ vextracti32x8 [dstq+strideq*1], m0, 1
+ lea dstq, [dstq+strideq*2]
+ sub hd, 2
+ jg .w32_loop
+ RET
+.w64:
+ pmovzxbq m7, [warp_8x8_shufA]
+.w64_loop:
+ %1 0
+ %1_INC_PTR 2
+ vpermq m0, m7, m0
+ mova [dstq], m0
+ add dstq, strideq
+ dec hd
+ jg .w64_loop
+ RET
+.w128:
+ pmovzxbq m7, [warp_8x8_shufA]
+.w128_loop:
+ %1 0
+ vpermq m6, m7, m0
+ %1 2
+ mova [dstq+64*0], m6
+ %1_INC_PTR 4
+ vpermq m6, m7, m0
+ mova [dstq+64*1], m6
+ add dstq, strideq
+ dec hd
+ jg .w128_loop
+ RET
+%else
+ %1 0
lea stride3q, [strideq*3]
jmp wq
.w4:
@@ -4042,7 +4158,7 @@
movd [dstq ], xm0
pextrd [dstq+strideq*1], xm0, 1
movd [dstq+strideq*2], xm1
- pextrd [dstq+stride3q], xm1, 1
+ pextrd [dstq+stride3q ], xm1, 1
lea dstq, [dstq+strideq*4]
pextrd [dstq ], xm0, 2
pextrd [dstq+strideq*1], xm0, 3
@@ -4084,7 +4200,7 @@
lea dstq, [dstq+strideq*2]
.w32:
vpermq m0, m0, q3120
- mova [dstq], m0
+ mova [dstq+strideq*0], m0
%1 2
vpermq m0, m0, q3120
mova [dstq+strideq*1], m0
@@ -4123,6 +4239,7 @@
dec hd
jg .w128_loop
RET
+%endif
%endmacro
%macro AVG 1 ; src_offset
@@ -4140,14 +4257,17 @@
add tmp2q, %1*mmsize
%endmacro
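+; avg/w_avg/mask are wrapped in *_FN macros so the same code can be instantiated for both AVX2 (ymm) and AVX-512 (zmm) at the bottom of the file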
+%macro AVG_FN 0
cglobal avg, 4, 7, 3, dst, stride, tmp1, tmp2, w, h, stride3
- lea r6, [avg_avx2_table]
+%define base r6-avg %+ SUFFIX %+ _table
+ lea r6, [avg %+ SUFFIX %+ _table]
tzcnt wd, wm
movifnidn hd, hm
movsxd wq, dword [r6+wq*4]
- vpbroadcastd m2, [pw_1024+r6-avg_avx2_table]
+ vpbroadcastd m2, [base+pw_1024]
add wq, r6
BIDIR_FN AVG
+%endmacro
%macro W_AVG 1 ; src_offset
; (a * weight + b * (16 - weight) + 128) >> 8
@@ -4169,13 +4289,15 @@
%define W_AVG_INC_PTR AVG_INC_PTR
+%macro W_AVG_FN 0
cglobal w_avg, 4, 7, 6, dst, stride, tmp1, tmp2, w, h, stride3
- lea r6, [w_avg_avx2_table]
+%define base r6-w_avg %+ SUFFIX %+ _table
+ lea r6, [w_avg %+ SUFFIX %+ _table]
tzcnt wd, wm
movifnidn hd, hm
vpbroadcastw m4, r6m ; weight
movsxd wq, dword [r6+wq*4]
- vpbroadcastd m5, [pw_2048+r6-w_avg_avx2_table]
+ vpbroadcastd m5, [base+pw_2048]
psllw m4, 12 ; (weight-16) << 12 when interpreted as signed
add wq, r6
cmp dword r6m, 7
@@ -4187,12 +4309,17 @@
mov tmp2q, r6
.weight_gt7:
BIDIR_FN W_AVG
+%endmacro
%macro MASK 1 ; src_offset
; (a * m + b * (64 - m) + 512) >> 10
; = ((a - b) * m + (b << 6) + 512) >> 10
; = ((((b - a) * (-m << 10)) >> 16) + b + 8) >> 4
- vpermq m3, [maskq+(%1+0)*(mmsize/2)], q3120
+%if mmsize == 64
+ vpermq m3, m8, [maskq+%1*32]
+%else
+ vpermq m3, [maskq+%1*16], q3120
+%endif
mova m0, [tmp2q+(%1+0)*mmsize]
psubw m1, m0, [tmp1q+(%1+0)*mmsize]
psubb m3, m4, m3
@@ -4214,20 +4341,26 @@
%macro MASK_INC_PTR 1
add maskq, %1*mmsize/2
- add tmp1q, %1*mmsize
add tmp2q, %1*mmsize
+ add tmp1q, %1*mmsize
%endmacro
+%macro MASK_FN 0
cglobal mask, 4, 8, 6, dst, stride, tmp1, tmp2, w, h, mask, stride3
- lea r7, [mask_avx2_table]
+%define base r7-mask %+ SUFFIX %+ _table
+ lea r7, [mask %+ SUFFIX %+ _table]
tzcnt wd, wm
movifnidn hd, hm
mov maskq, maskmp
movsxd wq, dword [r7+wq*4]
pxor m4, m4
- vpbroadcastd m5, [pw_2048+r7-mask_avx2_table]
+%if mmsize == 64
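+ ; qword permutation that lines the mask bytes up with the word order of m0/m1 in MASK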
+ mova m8, [base+bilin_v_perm64]
+%endif
+ vpbroadcastd m5, [base+pw_2048]
add wq, r7
BIDIR_FN MASK
+%endmacro
%macro W_MASK 2-3 0 ; src_offset, mask_out, 4:4:4
mova m0, [tmp1q+(%1+0)*mmsize]
@@ -5324,9 +5457,15 @@
INIT_YMM avx2
PREP_BILIN
PREP_8TAP
+AVG_FN
+W_AVG_FN
+MASK_FN
INIT_ZMM avx512icl
PREP_BILIN
PREP_8TAP
+AVG_FN
+W_AVG_FN
+MASK_FN
%endif ; ARCH_X86_64
--- a/src/x86/mc_init_tmpl.c
+++ b/src/x86/mc_init_tmpl.c
@@ -80,10 +80,13 @@
decl_mct_fn(dav1d_prep_bilin_avx2);
decl_mct_fn(dav1d_prep_bilin_ssse3);
+decl_avg_fn(dav1d_avg_avx512icl);
decl_avg_fn(dav1d_avg_avx2);
decl_avg_fn(dav1d_avg_ssse3);
+decl_w_avg_fn(dav1d_w_avg_avx512icl);
decl_w_avg_fn(dav1d_w_avg_avx2);
decl_w_avg_fn(dav1d_w_avg_ssse3);
+decl_mask_fn(dav1d_mask_avx512icl);
decl_mask_fn(dav1d_mask_avx2);
decl_mask_fn(dav1d_mask_ssse3);
decl_w_mask_fn(dav1d_w_mask_420_avx2);
@@ -172,10 +175,11 @@
c->warp8x8t = dav1d_warp_affine_8x8t_sse4;
#endif
+#if ARCH_X86_64
if (!(flags & DAV1D_X86_CPU_FLAG_AVX2))
return;
-#if BITDEPTH == 8 && ARCH_X86_64
+#if BITDEPTH == 8
init_mc_fn (FILTER_2D_8TAP_REGULAR, 8tap_regular, avx2);
init_mc_fn (FILTER_2D_8TAP_REGULAR_SMOOTH, 8tap_regular_smooth, avx2);
init_mc_fn (FILTER_2D_8TAP_REGULAR_SHARP, 8tap_regular_sharp, avx2);
@@ -217,7 +221,7 @@
if (!(flags & DAV1D_X86_CPU_FLAG_AVX512ICL))
return;
-#if BITDEPTH == 8 && ARCH_X86_64
+#if BITDEPTH == 8
init_mct_fn(FILTER_2D_8TAP_REGULAR, 8tap_regular, avx512icl);
init_mct_fn(FILTER_2D_8TAP_REGULAR_SMOOTH, 8tap_regular_smooth, avx512icl);
init_mct_fn(FILTER_2D_8TAP_REGULAR_SHARP, 8tap_regular_sharp, avx512icl);
@@ -228,5 +232,10 @@
init_mct_fn(FILTER_2D_8TAP_SHARP_SMOOTH, 8tap_sharp_smooth, avx512icl);
init_mct_fn(FILTER_2D_8TAP_SHARP, 8tap_sharp, avx512icl);
init_mct_fn(FILTER_2D_BILINEAR, bilin, avx512icl);
+
+ c->avg = dav1d_avg_avx512icl;
+ c->w_avg = dav1d_w_avg_avx512icl;
+ c->mask = dav1d_mask_avx512icl;
+#endif
#endif
}