shithub: dav1d

Download patch

ref: c204da0ff33a0d563d6c632b42799e4fbc48f402
parent: ac4e679bba97e915af98058831c6fa8d6254b5cd
author: Kyle Siefring <[email protected]>
date: Wed Mar 6 15:03:51 EST 2019

Use some 8 bit arithmetic in AVX2 CDEF filter

Before:
cdef_filter_8x8_8bpc_avx2: 252.3
cdef_filter_4x8_8bpc_avx2: 182.1
cdef_filter_4x4_8bpc_avx2: 105.7

After:
cdef_filter_8x8_8bpc_avx2: 235.5
cdef_filter_4x8_8bpc_avx2: 174.8
cdef_filter_4x4_8bpc_avx2: 101.8

--- a/src/x86/cdef.asm
+++ b/src/x86/cdef.asm
@@ -34,9 +34,13 @@
            dd 420, 210, 140, 105
 shufw_6543210x: db 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1, 14, 15
 shufw_210xxxxx: db 4, 5, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
+shufb_lohi: db 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15
 pw_128: times 2 dw 128
 pw_2048: times 2 dw 2048
-tap_table: dw 4, 2, 3, 3, 2, 1
+tap_table: ; masks for 8-bit shifts: entry [s] = 0xFF >> s, indexed by shift amount
+           db 0xFF, 0x7F, 0x3F, 0x1F, 0x0F, 0x07, 0x03, 0x01
+           ; weights
+           db 4, 2, 3, 3, 2, 1
            db -1 * 16 + 1, -2 * 16 + 2
            db  0 * 16 + 1, -1 * 16 + 2
            db  0 * 16 + 1,  0 * 16 + 2
@@ -55,29 +59,29 @@
 
 SECTION .text
 
-%macro ACCUMULATE_TAP 6 ; tap_offset, shift, strength, mul_tap, w, stride
+%macro ACCUMULATE_TAP 7 ; tap_offset, shift, mask, strength, mul_tap, w, stride
     ; load p0/p1
     movsx         offq, byte [dirq+kq+%1]       ; off1
-%if %5 == 4
-    movq           xm5, [stkq+offq*2+%6*0]      ; p0
-    movq           xm6, [stkq+offq*2+%6*2]
-    movhps         xm5, [stkq+offq*2+%6*1]
-    movhps         xm6, [stkq+offq*2+%6*3]
+%if %6 == 4
+    movq           xm5, [stkq+offq*2+%7*0]      ; p0
+    movq           xm6, [stkq+offq*2+%7*2]
+    movhps         xm5, [stkq+offq*2+%7*1]
+    movhps         xm6, [stkq+offq*2+%7*3]
     vinserti128     m5, xm6, 1
 %else
-    movu           xm5, [stkq+offq*2+%6*0]      ; p0
-    vinserti128     m5, [stkq+offq*2+%6*1], 1
+    movu           xm5, [stkq+offq*2+%7*0]      ; p0
+    vinserti128     m5, [stkq+offq*2+%7*1], 1
 %endif
     neg           offq                          ; -off1
-%if %5 == 4
-    movq           xm6, [stkq+offq*2+%6*0]      ; p1
-    movq           xm9, [stkq+offq*2+%6*2]
-    movhps         xm6, [stkq+offq*2+%6*1]
-    movhps         xm9, [stkq+offq*2+%6*3]
+%if %6 == 4
+    movq           xm6, [stkq+offq*2+%7*0]      ; p1
+    movq           xm9, [stkq+offq*2+%7*2]
+    movhps         xm6, [stkq+offq*2+%7*1]
+    movhps         xm9, [stkq+offq*2+%7*3]
     vinserti128     m6, xm9, 1
 %else
-    movu           xm6, [stkq+offq*2+%6*0]      ; p1
-    vinserti128     m6, [stkq+offq*2+%6*1], 1
+    movu           xm6, [stkq+offq*2+%7*0]      ; p1
+    vinserti128     m6, [stkq+offq*2+%7*1], 1
 %endif
     ; out of bounds values are set to a value that is a both a large unsigned
     ; value and a negative signed value.
@@ -88,24 +92,26 @@
     pminuw          m8, m6                      ; min after p1
 
     ; accumulate sum[m15] over p0/p1
+    ; calculate difference before converting
     psubw           m5, m4                      ; diff_p0(p0 - px)
     psubw           m6, m4                      ; diff_p1(p1 - px)
-    pabsw           m9, m5
-    pabsw          m10, m6
-    psignw         m11, %4, m5
-    psignw         m12, %4, m6
-    psrlw           m5, m9, %2
-    psrlw           m6, m10, %2
-    psubusw         m5, %3, m5
-    psubusw         m6, %3, m6
 
-    ; use unsigned min since abs diff can equal 0x8000
-    pminuw          m5, m9                      ; constrain(diff_p0)
-    pminuw          m6, m10                     ; constrain(diff_p1)
-    pmullw          m5, m11                     ; constrain(diff_p0) * taps
-    pmullw          m6, m12                     ; constrain(diff_p1) * taps
+    ; convert to 8 bits with signed saturation
+    ; saturating large diffs has no impact on the results
+    packsswb        m5, m6
+
+    ; group into pairs so we can accumulate using maddubsw
+    pshufb          m5, m12
+    pabsb           m9, m5
+    psignb         m10, %5, m5
+    psrlw           m5, m9, %2                  ; emulate 8-bit shift
+    pand            m5, %3
+    psubusb         m5, %4, m5
+
+    ; use unsigned min since abs diff can equal 0x80
+    pminub          m5, m9
+    pmaddubsw       m5, m10
     paddw          m15, m5
-    paddw          m15, m6
 %endmacro
 
 %macro cdef_filter_fn 3 ; w, h, stride
@@ -359,6 +365,9 @@
     INIT_YMM avx2
     DEFINE_ARGS dst, stride, pridmp, damping, pri, sec, stride3, secdmp
 %undef edged
+    ; pshufb mask used after packing to interleave the p0/p1 diff bytes into pairs
+    vbroadcasti128 m12, [shufb_lohi]
+
     movifnidn     prid, prim
     movifnidn     secd, secm
     mov       dampingd, r7m
@@ -379,21 +388,25 @@
     mov        [rsp+0], pridmpq                 ; pri_shift
     mov        [rsp+8], secdmpq                 ; sec_shift
 
+    DEFINE_ARGS dst, stride, pridmp, table, pri, sec, stride3, secdmp
+    lea         tableq, [tap_table]
+    vpbroadcastb   m13, [tableq+pridmpq]        ; pri_shift_mask
+    vpbroadcastb   m14, [tableq+secdmpq]        ; sec_shift_mask
+
     ; pri/sec_taps[k] [4 total]
-    DEFINE_ARGS dst, stride, tap, dummy, pri, sec, stride3
+    DEFINE_ARGS dst, stride, dummy, table, pri, sec, stride3
     movd           xm0, prid
     movd           xm1, secd
-    vpbroadcastw    m0, xm0                     ; pri_strength
-    vpbroadcastw    m1, xm1                     ; sec_strength
+    vpbroadcastb    m0, xm0                     ; pri_strength
+    vpbroadcastb    m1, xm1                     ; sec_strength
     and           prid, 1
-    lea           tapq, [tap_table]
-    lea           priq, [tapq+priq*4]           ; pri_taps
-    lea           secq, [tapq+8]                ; sec_taps
+    lea           priq, [tableq+priq*2+8]       ; pri_taps
+    lea           secq, [tableq+12]             ; sec_taps
 
-    ; off1/2/3[k] [6 total] from [tapq+12+(dir+0/2/6)*2+k]
+    ; off1/2/3[k] [6 total] from [tapq+14+(dir+0/2/6)*2+k]
-    DEFINE_ARGS dst, stride, tap, dir, pri, sec, stride3
+    DEFINE_ARGS dst, stride, dir, tap, pri, sec, stride3
     mov           dird, r6m
-    lea           tapq, [tapq+dirq*2+12]
+    lea           dirq, [tapq+dirq*2+14]
 %if %1*%2*2/mmsize > 1
  %if %1 == 4
     DEFINE_ARGS dst, stride, dir, stk, pri, sec, stride3, h, off, k
@@ -405,7 +418,7 @@
     DEFINE_ARGS dst, stride, dir, stk, pri, sec, stride3, off, k
 %endif
     lea           stkq, [px]
-    pxor           m13, m13
+    pxor           m11, m11
 %if %1*%2*2/mmsize > 1
 .v_loop:
 %endif
@@ -424,20 +437,20 @@
     mova            m7, m4                      ; max
     mova            m8, m4                      ; min
 .k_loop:
-    vpbroadcastw    m2, [priq+kq*2]             ; pri_taps
-    vpbroadcastw    m3, [secq+kq*2]             ; sec_taps
+    vpbroadcastb    m2, [priq+kq]               ; pri_taps
+    vpbroadcastb    m3, [secq+kq]               ; sec_taps
 
-    ACCUMULATE_TAP 0*2, [rsp+0], m0, m2, %1, %3
-    ACCUMULATE_TAP 2*2, [rsp+8], m1, m3, %1, %3
-    ACCUMULATE_TAP 6*2, [rsp+8], m1, m3, %1, %3
+    ACCUMULATE_TAP 0*2, [rsp+0], m13, m0, m2, %1, %3
+    ACCUMULATE_TAP 2*2, [rsp+8], m14, m1, m3, %1, %3
+    ACCUMULATE_TAP 6*2, [rsp+8], m14, m1, m3, %1, %3
 
     dec             kq
     jge .k_loop
 
-    vpbroadcastd   m12, [pw_2048]
-    pcmpgtw        m11, m13, m15
-    paddw          m15, m11
-    pmulhrsw       m15, m12
+    vpbroadcastd   m10, [pw_2048]
+    pcmpgtw         m9, m11, m15
+    paddw          m15, m9
+    pmulhrsw       m15, m10
     paddw           m4, m15
     pminsw          m4, m7
     pmaxsw          m4, m8