shithub: dav1d

Download patch

ref: 7f2833a9911f566da4753335a219a2fcfe187f75
parent: a02ed9c64cf7c69cc53739357e3ca0c981125c0e
author: Ronald S. Bultje <[email protected]>
date: Wed Mar 25 13:05:23 EDT 2020

x86: add AVX2 SIMD for ipred.cfl_ac[444]

cfl_ac_444_w4_8bpc_c: 499.1
cfl_ac_444_w4_8bpc_ssse3: 24.3
cfl_ac_444_w4_8bpc_avx2: 28.9
cfl_ac_444_w8_8bpc_c: 1240.2
cfl_ac_444_w8_8bpc_ssse3: 47.4
cfl_ac_444_w8_8bpc_avx2: 34.9
cfl_ac_444_w16_8bpc_c: 1785.7
cfl_ac_444_w16_8bpc_ssse3: 86.7
cfl_ac_444_w16_8bpc_avx2: 54.6
cfl_ac_444_w32_8bpc_c: 4343.5
cfl_ac_444_w32_8bpc_ssse3: 236.5
cfl_ac_444_w32_8bpc_avx2: 113.6

--- a/src/x86/ipred.asm
+++ b/src/x86/ipred.asm
@@ -100,6 +100,8 @@
               db  6,  6,  6,  6,  2,  2,  2,  2,  4,  4,  4,  4;  0,  0,  0,  0
 pw_64:        times 2 dw 64
 
+cfl_ac_444_w16_pad1_shuffle: db 0, -1, 1, -1, 2, -1, 3, -1, 4, -1, 5, -1, 6, -1 ; 32-byte pshufb mask (continued on the next line): source bytes 0..6, each paired with -1 so pshufb writes a zero high byte (byte -> zero-extended word)
+                             times 9 db 7, -1 ; source byte 7 as a word, 9 times: once to finish the low 16 bytes, 8 times to fill the high 16 bytes
 cfl_ac_w16_pad_shuffle: ; w=16, w_pad=1
                         db 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
                         ; w=8, w_pad=1 as well as second half of previous one
@@ -166,6 +168,7 @@
 JMP_TABLE ipred_cfl_left,   avx2, h4, h8, h16, h32
 JMP_TABLE ipred_cfl_ac_420, avx2, w16_pad1, w16_pad2, w16_pad3
 JMP_TABLE ipred_cfl_ac_422, avx2, w16_pad1, w16_pad2, w16_pad3
+JMP_TABLE ipred_cfl_ac_444, avx2, w32_pad1, w32_pad2, w32_pad3, w4, w8, w16, w32 ; the three .w32_pad* targets come before the per-width targets; ipred_cfl_ac_444 indexes this table with constant offsets (+12 for widths, +4 for w32 padding)
 JMP_TABLE pal_pred,         avx2, w4, w8, w16, w32, w64
 
 cextern dr_intra_derivative
@@ -5035,6 +5038,236 @@
     vextracti128        xm1, m0, 1
     tzcnt               r1d, szd
     paddd               xm0, xm1
+    movd                xm2, r1d
+    movd                xm3, szd
+    punpckhqdq          xm1, xm0, xm0
+    paddd               xm0, xm1
+    psrad               xm3, 1
+    psrlq               xm1, xm0, 32
+    paddd               xm0, xm3
+    paddd               xm0, xm1
+    psrad               xm0, xm2
+    vpbroadcastw         m0, xm0
+.sub_loop:
+    mova                 m1, [ac_bakq]
+    psubw                m1, m0
+    mova          [ac_bakq], m1
+    add             ac_bakq, 32
+    sub                 szd, 16
+    jg .sub_loop
+    RET
+
+cglobal ipred_cfl_ac_444, 4, 9, 6, ac, y, stride, wpad, hpad, w, h, sz, ac_bak ; AVX2 cfl_ac for 4:4:4: writes (y pixel << 3) as 16-bit values to ac, padding right/bottom by replication, and sums everything written in m4
+    movifnidn         hpadd, hpadm ; fetch hpad unless it already sits in its register
+    movifnidn            wd, wm ; likewise for w
+    mov                  hd, hm
+    mov                 szd, wd
+    imul                szd, hd ; sz = w * h, taken before h is reduced by the padding
+    shl               hpadd, 2 ; hpad *= 4: from here on hpad counts rows
+    sub                  hd, hpadd ; h = number of rows actually read from y
+    pxor                 m4, m4 ; m4 = running sum of the values stored to ac
+    vpbroadcastd         m5, [pw_1] ; pmaddwd multiplier; presumably every word is 1 (pw_1 is defined outside this hunk)
+    tzcnt               r8d, wd ; bit index of w's lowest set bit
+    lea                  r5, [ipred_cfl_ac_444_avx2_table]
+    movsxd               r8, [r5+r8*4+12] ; 32-bit table entry for this width; +12 steps over the three w32_pad* entries. NOTE(review): the index base depends on the JMP_TABLE macro, which is not shown here - confirm
+    add                  r5, r8 ; entry is an offset from the table symbol; add the base to get the target address
+
+    DEFINE_ARGS ac, y, stride, wpad, hpad, stride3, h, sz, ac_bak ; same name list except the sixth name: w becomes stride3
+    mov             ac_bakq, acq ; keep the start of ac, since acq is advanced while storing
+    jmp                  r5 ; dispatch to the per-width code
+
+.w4: ; width 4: four rows per iteration, two rows per xmm register
+    lea            stride3q, [strideq*3]
+    pxor                xm2, xm2 ; zero register used to widen bytes to words
+.w4_loop:
+    movd                xm1, [yq] ; row 0
+    movd                xm0, [yq+strideq*2] ; row 2
+    pinsrd              xm1, [yq+strideq], 1 ; row 1 -> xm1 = rows 0,1
+    pinsrd              xm0, [yq+stride3q], 1 ; row 3 -> xm0 = rows 2,3
+    punpcklbw           xm1, xm2 ; 8 bytes -> 8 words
+    punpcklbw           xm0, xm2
+    psllw               xm1, 3 ; pixel << 3
+    psllw               xm0, 3
+    mova              [acq], xm1
+    mova           [acq+16], xm0
+    paddw               xm1, xm0
+    paddw               xm4, xm1 ; 16-bit partial sums; widened to 32 bits at .calc_avg_mul
+    lea                  yq, [yq+strideq*4]
+    add                 acq, 32 ; 4 rows * 4 words * 2 bytes
+    sub                  hd, 4
+    jg .w4_loop
+    test              hpadd, hpadd
+    jz .calc_avg_mul ; no bottom padding requested
+    pshufd              xm0, xm0, q3232 ; both halves = upper 64 bits, i.e. row 3 (the last row read)
+    paddw               xm1, xm0, xm0 ; per-iteration sum, because the loop stores xm0 twice
+.w4_hpad_loop: ; each pass writes four copies of the last row
+    mova              [acq], xm0
+    mova           [acq+16], xm0
+    paddw               xm4, xm1
+    add                 acq, 32
+    sub               hpadd, 4
+    jg .w4_hpad_loop
+    jmp .calc_avg_mul
+
+.w8: ; width 8: four rows per iteration, one row per 128-bit lane
+    lea            stride3q, [strideq*3]
+    pxor                 m2, m2 ; zero register used to widen bytes to words
+.w8_loop:
+    movq                xm1, [yq] ; row 0 -> low lane of m1
+    movq                xm0, [yq+strideq*2] ; row 2 -> low lane of m0
+    vinserti128          m1, [yq+strideq], 1 ; row 1 -> high lane. NOTE(review): loads 16 bytes but only the low 8 are consumed below; assumes reading 8 bytes past the row is safe - confirm
+    vinserti128          m0, [yq+stride3q], 1 ; row 3 -> high lane (same 16-byte load)
+    punpcklbw            m1, m2 ; low 8 bytes of each lane -> 8 words
+    punpcklbw            m0, m2
+    psllw                m1, 3 ; pixel << 3
+    psllw                m0, 3
+    mova              [acq], m1 ; rows 0,1
+    mova           [acq+32], m0 ; rows 2,3
+    paddw                m1, m0
+    paddw                m4, m1 ; 16-bit partial sums; widened to 32 bits at .calc_avg_mul
+    lea                  yq, [yq+strideq*4]
+    add                 acq, 64 ; 4 rows * 8 words * 2 bytes
+    sub                  hd, 4
+    jg .w8_loop
+    test              hpadd, hpadd
+    jz .calc_avg_mul ; no bottom padding requested
+    vpermq               m0, m0, q3232 ; both lanes = high lane, i.e. row 3 (the last row read)
+    paddw                m1, m0, m0 ; per-iteration sum, because the loop stores m0 twice
+.w8_hpad_loop: ; each pass writes four copies of the last row
+    mova              [acq], m0
+    mova           [acq+32], m0
+    paddw                m4, m1
+    add                 acq, 64
+    sub               hpadd, 4
+    jg .w8_hpad_loop
+    jmp .calc_avg_mul
+
+.w16: ; width 16: two rows per iteration, one row per ymm register
+    test              wpadd, wpadd
+    jnz .w16_wpad ; any non-zero wpad takes the 8-valid-pixel path
+.w16_loop:
+    pmovzxbw             m1, [yq] ; 16 pixels -> 16 words
+    pmovzxbw             m0, [yq+strideq]
+    psllw                m1, 3 ; pixel << 3
+    psllw                m0, 3
+    mova              [acq], m1
+    mova           [acq+32], m0
+    paddw                m1, m0
+    pmaddwd              m1, m5 ; adjacent word pairs combined into 32-bit lanes (m5 is the multiplier)
+    paddd                m4, m1 ; 32-bit sums, so this path ends at .calc_avg rather than .calc_avg_mul
+    lea                  yq, [yq+strideq*2]
+    add                 acq, 64 ; 2 rows * 16 words * 2 bytes
+    sub                  hd, 2
+    jg .w16_loop
+    test              hpadd, hpadd
+    jz .calc_avg
+    jmp .w16_hpad
+.w16_wpad:
+    mova                 m3, [cfl_ac_444_w16_pad1_shuffle] ; mask: low lane = bytes 0..7 as words, high lane = byte 7 as a word, 8 times
+.w16_wpad_loop:
+    vpbroadcastq         m1, [yq] ; the same 8 pixels in every qword, so both lanes can shuffle from them
+    vpbroadcastq         m0, [yq+strideq]
+    pshufb               m1, m3 ; pixels 0..7 followed by pixel 7 replicated into the padded half
+    pshufb               m0, m3
+    psllw                m1, 3
+    psllw                m0, 3
+    mova              [acq], m1
+    mova           [acq+32], m0
+    paddw                m1, m0
+    pmaddwd              m1, m5
+    paddd                m4, m1
+    lea                  yq, [yq+strideq*2]
+    add                 acq, 64
+    sub                  hd, 2
+    jg .w16_wpad_loop
+    test              hpadd, hpadd
+    jz .calc_avg
+.w16_hpad: ; reached from both loops; m0 = second row of the last pair stored
+    paddw                m1, m0, m0 ; the loop below stores m0 twice
+    pmaddwd              m1, m5 ; per-iteration sum as 32-bit lanes
+.w16_hpad_loop:
+    mova              [acq], m0
+    mova           [acq+32], m0
+    paddd                m4, m1
+    add                 acq, 64
+    sub               hpadd, 2
+    jg .w16_hpad_loop
+    jmp .calc_avg
+
+.w32: ; width 32: one row per iteration, split across m1 (pixels 0..15) and m0 (pixels 16..31)
+    test              wpadd, wpadd
+    jnz .w32_wpad
+.w32_loop:
+    pmovzxbw             m1, [yq] ; pixels 0..15 -> words
+    pmovzxbw             m0, [yq+16] ; pixels 16..31 -> words
+    psllw                m1, 3 ; pixel << 3
+    psllw                m0, 3
+    mova              [acq], m1
+    mova           [acq+32], m0
+    paddw                m2, m1, m0 ; m1/m0 stay intact so the bottom-padding loop can re-store them
+    pmaddwd              m2, m5 ; row sum as 32-bit lanes (m5 is the multiplier)
+    paddd                m4, m2
+    add                  yq, strideq
+    add                 acq, 64 ; 32 words * 2 bytes
+    dec                  hd
+    jg .w32_loop
+    test              hpadd, hpadd
+    jz .calc_avg
+    jmp .w32_hpad_loop ; m1, m0, m2 still describe the last row read
+.w32_wpad:
+    DEFINE_ARGS ac, y, stride, wpad, hpad, iptr, h, sz, ac_bak ; sixth register is now named iptr
+    lea               iptrq, [ipred_cfl_ac_444_avx2_table]
+    add               wpadd, wpadd ; wpad *= 2, used as a byte offset into the table below
+    mova                 m3, [iptrq+cfl_ac_444_w16_pad1_shuffle-ipred_cfl_ac_444_avx2_table] ; shuffle mask, addressed relative to the table base already held in iptrq
+    movsxd            wpadq, [iptrq+wpadq+4] ; table entry for this wpad. NOTE(review): which wpad value selects pad1/pad2/pad3 depends on the JMP_TABLE base and on wpad's units, neither visible here - confirm
+    add               iptrq, wpadq ; iptrq = address of the chosen .w32_padN loader
+    jmp iptrq
+.w32_pad3: ; 8 valid pixels
+    vpbroadcastq         m1, [yq]
+    pshufb               m1, m3 ; low lane = pixels 0..7, high lane = pixel 7 replicated
+    vpermq               m0, m1, q3232 ; both lanes = high lane of m1, i.e. pixel 7 in all 16 words
+    jmp .w32_wpad_end
+.w32_pad2: ; 16 valid pixels
+    pmovzxbw             m1, [yq] ; pixels 0..15
+    pshufhw              m0, m1, q3333 ; in each lane, the upper four words all become that lane's last word
+    vpermq               m0, m0, q3333 ; every qword = top qword, i.e. pixel 15 in all 16 words
+    jmp .w32_wpad_end
+.w32_pad1: ; 24 valid pixels
+    pmovzxbw             m1, [yq] ; pixels 0..15
+    vpbroadcastq         m0, [yq+16] ; pixels 16..23 in every qword
+    pshufb               m0, m3 ; pixels 16..23 followed by pixel 23 replicated 8 times
+    ; fall-through
+.w32_wpad_end: ; common per-row tail for all three loaders
+    psllw                m1, 3
+    psllw                m0, 3
+    mova              [acq], m1
+    mova           [acq+32], m0
+    paddw                m2, m1, m0
+    pmaddwd              m2, m5
+    paddd                m4, m2
+    add                  yq, strideq
+    add                 acq, 64
+    dec                  hd
+    jz .w32_wpad_done
+    jmp iptrq ; next row: re-enter the loader selected above
+.w32_wpad_done:
+    test              hpadd, hpadd
+    jz .calc_avg
+.w32_hpad_loop: ; one replicated row per pass: m1/m0 = last row stored, m2 = its 32-bit sum
+    mova              [acq], m1
+    mova           [acq+32], m0
+    paddd                m4, m2
+    add                 acq, 64
+    dec               hpadd
+    jg .w32_hpad_loop
+    jmp .calc_avg
+
+.calc_avg_mul: ; entry for .w4/.w8, whose m4 still holds 16-bit partial sums
+    pmaddwd              m4, m5 ; combine adjacent word sums into 32-bit lanes (m5 is the multiplier)
+.calc_avg: ; entry for .w16/.w32, whose m4 already holds 32-bit sums
+    vextracti128        xm1, m4, 1 ; upper 128-bit lane of the sum
+    tzcnt               r1d, szd ; bit index of sz's lowest set bit, placed in r1 (the register named y in the argument list)
+    paddd               xm0, xm4, xm1 ; fold the upper lane into the lower one
     movd                xm2, r1d
     movd                xm3, szd
     punpckhqdq          xm1, xm0, xm0
--- a/src/x86/ipred_init_tmpl.c
+++ b/src/x86/ipred_init_tmpl.c
@@ -50,6 +50,7 @@
 
 decl_cfl_ac_fn(dav1d_ipred_cfl_ac_420_avx2);
 decl_cfl_ac_fn(dav1d_ipred_cfl_ac_422_avx2);
+decl_cfl_ac_fn(dav1d_ipred_cfl_ac_444_avx2);
 
 decl_pal_pred_fn(dav1d_pal_pred_avx2);
 
@@ -131,6 +132,7 @@
 
     c->cfl_ac[DAV1D_PIXEL_LAYOUT_I420 - 1] = dav1d_ipred_cfl_ac_420_avx2;
     c->cfl_ac[DAV1D_PIXEL_LAYOUT_I422 - 1] = dav1d_ipred_cfl_ac_422_avx2;
+    c->cfl_ac[DAV1D_PIXEL_LAYOUT_I444 - 1] = dav1d_ipred_cfl_ac_444_avx2;
 
     c->pal_pred = dav1d_pal_pred_avx2;
 #endif