shithub: dav1d

Download patch

ref: f16b43cdfa2f3f2d5af36185819bebf1ca9c806d
parent: 502204562c7458d42812dcf5c32b0cbc28150c25
author: Henrik Gramner <[email protected]>
date: Tue Jan 7 19:43:58 EST 2020

x86: Fix AVX2 inverse identity transform overflow/clipping

The coefficients after the first (8-bit) 1D identity transform may
require more than 16 bits of precision before downshifting in some cases,
and they may also need to be clipped to int16_t after downshifting.

--- a/src/x86/itx.asm
+++ b/src/x86/itx.asm
@@ -60,7 +60,6 @@
 pw_1697x16: times 2 dw 1697*16
 pw_1697x8:  times 2 dw 1697*8
 pw_2896x8:  times 2 dw 2896*8
-pw_5793x4:  times 2 dw 5793*4
 
 pd_2048: dd 2048
 
@@ -393,7 +392,7 @@
     pmulhrsw             m0, [cq]
     vpbroadcastd         m1, [o(pw_1697x8)]
     pmulhrsw             m1, m0
-    paddw                m0, m1
+    paddsw               m0, m1
     punpcklwd            m0, m0
     punpckhdq            m1, m0, m0
     punpckldq            m0, m0
@@ -405,7 +404,7 @@
     vpbroadcastd         m2, [o(pw_2896x8)]
     packusdw             m0, m0
     pmulhrsw             m1, m0
-    paddw                m0, m1
+    paddsw               m0, m1
     pmulhrsw             m0, m2
     mova                 m1, m0
     jmp m(iadst_4x4_internal).end
@@ -561,8 +560,8 @@
     vpbroadcastd         m3, [o(pw_1697x8)]
     pmulhrsw             m2, m3, m0
     pmulhrsw             m3, m1
-    paddw                m0, m2
-    paddw                m1, m3
+    paddsw               m0, m2
+    paddsw               m1, m3
     punpckhwd            m2, m0, m1
     punpcklwd            m0, m1
     punpckhwd            m1, m0, m2
@@ -572,8 +571,8 @@
     vpbroadcastd         m3, [o(pw_1697x8)]
     pmulhrsw             m2, m3, m0
     pmulhrsw             m3, m1
-    paddw                m0, m2
-    paddw                m1, m3
+    paddsw               m0, m2
+    paddsw               m1, m3
     jmp m(iadst_4x4_internal).end
 
 %macro WRITE_4X8 2 ; coefs[1-2]
@@ -626,7 +625,7 @@
     punpckldq           xm0, xm1
     pmulhrsw            xm0, xm2
     pmulhrsw            xm3, xm0
-    paddw               xm0, xm3
+    paddsw              xm0, xm3
     pmulhrsw            xm0, xm2
     pmulhrsw            xm0, xm4
     vpbroadcastq         m0, xm0
@@ -907,8 +906,8 @@
     punpckhwd            m1, m2
     pmulhrsw             m2, m4, m0
     pmulhrsw             m4, m1
-    paddw                m0, m2
-    paddw                m1, m4
+    paddsw               m0, m2
+    paddsw               m1, m4
     jmp                tx2q
 .pass2:
     vpbroadcastd         m4, [o(pw_4096)]
@@ -925,8 +924,8 @@
     vpbroadcastd         m3, [o(pw_2048)]
     pmulhrsw             m0, m1
     pmulhrsw             m2, m0
-    paddw                m0, m0
-    paddw                m0, m2
+    paddsw               m0, m0
+    paddsw               m0, m2
     pmulhrsw             m3, m0
     punpcklwd            m1, m3, m3
     punpckhwd            m3, m3
@@ -941,15 +940,16 @@
     movd                xm1, [cq+32*2]
     punpcklwd           xm1, [cq+32*3]
     vpbroadcastd        xm2, [o(pw_1697x8)]
-    vpbroadcastd        xm3, [o(pw_16384)]
-    vpbroadcastd        xm4, [o(pw_2896x8)]
+    vpbroadcastd        xm3, [o(pw_2896x8)]
+    vpbroadcastd        xm4, [o(pw_2048)]
     punpckldq           xm0, xm1
+    pcmpeqw             xm1, xm1
     pmulhrsw            xm2, xm0
-    paddw               xm0, xm2
+    pcmpeqw             xm1, xm0
+    pxor                xm0, xm1
+    pavgw               xm0, xm2
     pmulhrsw            xm0, xm3
-    psrlw               xm3, 3 ; pw_2048
     pmulhrsw            xm0, xm4
-    pmulhrsw            xm0, xm3
     vpbroadcastq         m0, xm0
     mova                 m1, m0
     mova                 m2, m0
@@ -1283,26 +1283,33 @@
     mova                 m3, [cq+32*0]
     mova                 m2, [cq+32*1]
     mova                 m4, [cq+32*2]
-    mova                 m0, [cq+32*3]
-    vpbroadcastd         m5, [o(pw_1697x8)]
+    mova                 m5, [cq+32*3]
+    vpbroadcastd         m8, [o(pw_1697x8)]
+    pcmpeqw              m0, m0 ; -1
     punpcklwd            m1, m3, m2
     punpckhwd            m3, m2
-    punpcklwd            m2, m4, m0
-    punpckhwd            m4, m0
-    pmulhrsw             m0, m5, m1
-    pmulhrsw             m6, m5, m2
-    pmulhrsw             m7, m5, m3
-    pmulhrsw             m5, m4
-    paddw                m1, m0
-    paddw                m2, m6
-    paddw                m3, m7
-    paddw                m4, m5
-    vpbroadcastd         m5, [o(pw_16384)]
+    punpcklwd            m2, m4, m5
+    punpckhwd            m4, m5
+    pmulhrsw             m5, m8, m1
+    pmulhrsw             m6, m8, m2
+    pmulhrsw             m7, m8, m3
+    pmulhrsw             m8, m4
+    pcmpeqw              m9, m0, m1 ; we want to do a signed avg, but pavgw is
+    pxor                 m1, m9     ; unsigned. as long as both signs are equal
+    pcmpeqw              m9, m0, m2 ; it still works, but if the input is -1 the
+    pxor                 m2, m9     ; pmulhrsw result will become 0 which causes
+    pcmpeqw              m9, m0, m3 ; pavgw to output -32768 instead of 0 unless
+    pxor                 m3, m9     ; we explicitly deal with that case here.
+    pcmpeqw              m0, m4
+    pxor                 m4, m0
+    pavgw                m1, m5
+    pavgw                m2, m6
+    pavgw                m3, m7
+    pavgw                m4, m8
     punpckldq            m0, m1, m2
     punpckhdq            m1, m2
     punpckldq            m2, m3, m4
     punpckhdq            m3, m4
-    REPX   {pmulhrsw x, m5}, m0, m1, m2, m3
     jmp                tx2q
 .pass2:
     vpbroadcastd         m8, [o(pw_1697x16)]
@@ -1311,11 +1318,11 @@
     pmulhrsw             m6, m8, m1
     pmulhrsw             m7, m8, m2
     pmulhrsw             m8, m3
-    REPX       {paddw x, x}, m0, m1, m2, m3
-    paddw                m0, m4
-    paddw                m1, m6
-    paddw                m2, m7
-    paddw                m3, m8
+    REPX      {paddsw x, x}, m0, m1, m2, m3
+    paddsw               m0, m4
+    paddsw               m1, m6
+    paddsw               m2, m7
+    paddsw               m3, m8
     jmp m(iadst_4x16_internal).end2
 
 %macro WRITE_8X4 4-7 strideq*1, strideq*2, r3, ; coefs[1-2], tmp[1-2], off[1-3]
@@ -1353,7 +1360,7 @@
     vpbroadcastd        xm3, [o(pw_2048)]
     pmulhrsw            xm1, xm0
     pmulhrsw            xm2, xm1
-    paddw               xm1, xm2
+    paddsw              xm1, xm2
     pmulhrsw            xm1, xm3
     punpcklwd           xm1, xm1
     punpckldq           xm0, xm1, xm1
@@ -1369,7 +1376,7 @@
     vpbroadcastd        xm3, [o(pw_2048)]
     packusdw            xm0, xm1
     pmulhrsw            xm0, xm2
-    paddw               xm0, xm0
+    paddsw              xm0, xm0
     pmulhrsw            xm0, xm2
     pmulhrsw            xm0, xm3
     vinserti128          m0, m0, xm0, 1
@@ -1520,15 +1527,15 @@
     pmulhrsw             m2, m3
     punpcklwd            m0, m1, m2
     punpckhwd            m1, m2
-    paddw                m0, m0
-    paddw                m1, m1
+    paddsw               m0, m0
+    paddsw               m1, m1
     jmp                tx2q
 .pass2:
     vpbroadcastd         m3, [o(pw_1697x8)]
     pmulhrsw             m2, m3, m0
     pmulhrsw             m3, m1
-    paddw                m0, m2
-    paddw                m1, m3
+    paddsw               m0, m2
+    paddsw               m1, m3
     jmp m(iadst_8x4_internal).end
 
 %macro INV_TXFM_8X8_FN 2-3 -1 ; type1, type2, fast_thresh
@@ -1796,8 +1803,8 @@
     pmulhrsw             m7, m1
     psrlw                m1, 3 ; pw_2048
     pmulhrsw             m2, m7
-    paddw                m7, m7
-    paddw                m7, m2
+    paddsw               m7, m7
+    paddsw               m7, m2
     pmulhrsw             m7, m1
     punpcklwd            m5, m7, m7
     punpckhwd            m7, m7
@@ -2120,12 +2127,12 @@
 
 %macro IDTX16 3-4 ; src/dst, tmp, pw_1697x16, [pw_16394]
     pmulhrsw            m%2, m%3, m%1
-%if %0 == 4 ; if we're going to downshift by 1 doing so here eliminates the paddw
+%if %0 == 4 ; if downshifting by 1
     pmulhrsw            m%2, m%4
 %else
-    paddw               m%1, m%1
+    paddsw              m%1, m%1 ; 2*src, saturated to int16 instead of wrapping
 %endif
-    paddw               m%1, m%2
+    paddsw              m%1, m%2 ; add the scaled term in tmp, saturated to int16
 %endmacro
 
 cglobal iidentity_8x16_internal, 0, 5, 13, dst, stride, c, eob, tx2
@@ -2201,7 +2208,7 @@
     pmulhrsw            xm3, xm0
     psrlw               xm0, 3 ; pw_2048
     pmulhrsw            xm1, xm3
-    paddw               xm3, xm1
+    paddsw              xm3, xm1
     pmulhrsw            xm3, xm0
     punpcklwd           xm3, xm3
     punpckldq           xm1, xm3, xm3
@@ -2228,7 +2235,7 @@
     vpbroadcastd         m1, [o(pw_2896x8)]
     pmulhrsw             m4, m0
     pmulhrsw             m4, m5
-    paddw                m0, m4
+    paddsw               m0, m4
     psrlw                m5, 3 ; pw_2048
     pmulhrsw             m0, m1
     pmulhrsw             m0, m5
@@ -2503,10 +2510,10 @@
     pmulhrsw             m6, m7, m3
     pmulhrsw             m7, m4
     REPX   {pmulhrsw x, m8}, m0, m5, m6, m7
-    paddw                m1, m0
-    paddw                m2, m5
-    paddw                m3, m6
-    paddw                m4, m7
+    paddsw               m1, m0
+    paddsw               m2, m5
+    paddsw               m3, m6
+    paddsw               m4, m7
     punpcklqdq           m0, m1, m2
     punpckhqdq           m1, m2
     punpcklqdq           m2, m3, m4
@@ -2518,10 +2525,10 @@
     pmulhrsw             m5, m7, m1
     pmulhrsw             m6, m7, m2
     pmulhrsw             m7, m3
-    paddw                m0, m4
-    paddw                m1, m5
-    paddw                m2, m6
-    paddw                m3, m7
+    paddsw               m0, m4
+    paddsw               m1, m5
+    paddsw               m2, m6
+    paddsw               m3, m7
     jmp m(iadst_16x4_internal).end
 
 %macro INV_TXFM_16X8_FN 2-3 -1 ; type1, type2, fast_thresh
@@ -2581,7 +2588,7 @@
     pmulhrsw             m0, m4
     pmulhrsw             m5, m0
     pmulhrsw             m5, m2
-    paddw                m0, m5
+    paddsw               m0, m5
     psrlw                m2, 3 ; pw_2048
     pmulhrsw             m0, m4
     pmulhrsw             m0, m2
@@ -2903,7 +2910,7 @@
     vpbroadcastd         m3, [o(pw_2896x8)]
     pmulhrsw             m3, [cq]
     vpbroadcastd         m0, [o(pw_8192)]
-    vpbroadcastd         m1, [o(pw_5793x4)]
+    vpbroadcastd         m1, [o(pw_1697x16)]
     vpbroadcastw         m4, [o(deint_shuf)] ; pb_0_1
     pcmpeqb              m5, m5
     pxor                 m6, m6
@@ -2911,8 +2918,7 @@
     paddb                m5, m5 ; pb_m2
     pmulhrsw             m3, m0
     psrlw                m0, 2  ; pw_2048
-    psllw                m3, 2
-    pmulhrsw             m3, m1
+    IDTX16                3, 1, 1
     pmulhrsw             m3, m0
     mov                 r3d, 8
 .loop:
@@ -2954,17 +2960,15 @@
     punpcklwd            m1, m3
     vpbroadcastd         m3, [o(pw_1697x16)]
     punpcklwd            m2, m4
-    vpbroadcastd         m4, [o(pw_8192)]
+    vpbroadcastd         m4, [o(pw_2896x8)]
     punpckldq            m1, m2
-    vpbroadcastd         m2, [o(pw_2896x8)]
+    vpbroadcastd         m2, [o(pw_2048)]
     punpcklqdq           m0, m1
     pmulhrsw             m3, m0
-    paddw                m0, m0
-    paddw                m0, m3
+    psraw                m3, 1
+    pavgw                m0, m3
     pmulhrsw             m0, m4
-    psrlw                m4, 2 ; pw_2048
     pmulhrsw             m0, m2
-    pmulhrsw             m0, m4
     mov                 r3d, 8
     jmp m(inv_txfm_add_identity_dct_16x4).end
 %endif
@@ -3385,6 +3389,12 @@
     WRITE_16X2            7, [rsp+32*2],  0,  1, strideq*2, r3
     jmp m(idct_16x16_internal).end3
 
+%macro IDTX16B 3 ; src/dst, tmp, pw_1697x16
+    pmulhrsw            m%2, m%3, m%1 ; tmp = (src * m%3 + 0x4000) >> 15
+    psraw               m%2, 1 ; arithmetic shift: halve tmp, keeping its sign
+    pavgw               m%1, m%2 ; (src + tmp + 1) >> 1; pavgw is unsigned, but signs are guaranteed to be equal
+%endmacro
+
 INV_TXFM_16X16_FN identity, dct,      15
 INV_TXFM_16X16_FN identity, identity
 
@@ -3419,22 +3429,17 @@
     vinserti128         m13, [cq+16*13], 1
     mova               xm14, [cq-16* 1]
     vinserti128         m14, [cq+16*15], 1
-    REPX   {IDTX16 x, 6, 7},  0, 15,  1,  8,  2,  9,  3, \
+    REPX  {IDTX16B x, 6, 7},  0, 15,  1,  8,  2,  9,  3, \
                              10,  4, 11,  5, 12, 13, 14
     mova                xm6, [cq-16* 4]
     vinserti128          m6, [cq+16*12], 1
-    mova              [rsp], m1
-    IDTX16                6, 1, 7
-    mova                xm1, [cq-16* 2]
-    vinserti128          m1, [cq+16*14], 1
-    pmulhrsw             m7, m1
-    paddw                m1, m1
-    paddw                m7, m1
-    vpbroadcastd         m1, [o(pw_8192)]
-    REPX   {pmulhrsw x, m1}, m0,       m2,  m3,  m4,  m5,  m6,  m7, \
-                             m8,  m9,  m10, m11, m12, m13, m14, m15
-    pmulhrsw             m1, [rsp]
     mova              [rsp], m0
+    IDTX16B               6, 0, 7
+    mova                xm0, [cq-16* 2]
+    vinserti128          m0, [cq+16*14], 1
+    pmulhrsw             m7, m0
+    psraw                m7, 1
+    pavgw                m7, m0
     jmp m(idct_16x16_internal).pass1_end3
 ALIGN function_align
 .pass2:
@@ -3963,7 +3968,7 @@
     vinserti128          m6, m6, [cq+16* 9], 1
     vinserti128          m7, m7, [cq+16*13], 1
     REPX {mova [cq+32*x], m8}, -4, -2,  0,  2,  4,  6
-    REPX  {paddw     x, m9}, m0, m1, m2, m3, m4, m5, m6, m7
+    REPX  {paddsw    x, m9}, m0, m1, m2, m3, m4, m5, m6, m7
     call .transpose8x8
     REPX  {psraw     x, 3 }, m0, m1, m2, m3, m4, m5, m6, m7
     WRITE_8X4             0,  4,  8, 10, strideq*8, strideq*4, r4*4
@@ -4572,12 +4577,12 @@
     IDCT32_PASS1_END      1,  9,  6,  7
     ret
 
-cglobal inv_txfm_add_identity_identity_16x32, 4, 5, 12, dst, stride, c, eob
+cglobal inv_txfm_add_identity_identity_16x32, 4, 5, 13, dst, stride, c, eob
 %undef cmp
     lea                 rax, [o_base]
     vpbroadcastd         m9, [o(pw_2896x8)]
-    vpbroadcastd        m10, [o(pw_5793x4)]
-    vpbroadcastd        m11, [o(pw_5)]
+    vpbroadcastd        m10, [o(pw_1697x16)]
+    vpbroadcastd        m12, [o(pw_8192)]
     cmp                eobd, 43   ; if (eob > 43)
     setg                r4b       ;   iteration_count++
     cmp                eobd, 150  ; if (eob > 150)
@@ -4586,6 +4591,7 @@
     adc                 r4b, al   ;   iteration_count++
     lea                  r3, [strideq*3]
     mov                 rax, cq
+    paddw               m11, m12, m12 ; pw_16384
 .loop:
     mova                xm0,     [cq+64* 0]
     mova                xm1,     [cq+64* 1]
@@ -4604,11 +4610,9 @@
     vinserti128          m6, m6, [cq+64*14], 1
     vinserti128          m7, m7, [cq+64*15], 1
     REPX  {pmulhrsw x, m9 }, m0, m1, m2, m3, m4, m5, m6, m7
-    REPX  {psllw    x, 2  }, m0, m1, m2, m3, m4, m5, m6, m7
+    REPX  {IDTX16 x, 8, 10, 11}, 0, 1, 2, 3, 4, 5, 6, 7
     call m(inv_txfm_add_identity_identity_8x32).transpose8x8
-    REPX  {pmulhrsw x, m10}, m0, m1, m2, m3, m4, m5, m6, m7
-    REPX  {paddw    x, m11}, m0, m1, m2, m3, m4, m5, m6, m7
-    REPX  {psraw    x, 3  }, m0, m1, m2, m3, m4, m5, m6, m7
+    REPX  {pmulhrsw x, m12}, m0, m1, m2, m3, m4, m5, m6, m7
     WRITE_16X2            0,  1,  8,  0, strideq*0, strideq*1
     WRITE_16X2            2,  3,  0,  1, strideq*2, r3
     lea                dstq, [dstq+strideq*4]
@@ -4646,7 +4650,7 @@
 %undef cmp
     lea                 rax, [o_base]
     vpbroadcastd         m9, [o(pw_2896x8)]
-    vpbroadcastd        m10, [o(pw_1697x8)]
+    vpbroadcastd        m10, [o(pw_1697x16)]
     vpbroadcastd        m11, [o(pw_2048)]
     cmp                eobd, 35  ; if (eob > 35)
     setg                r4b      ;   iteration_count++
@@ -4674,24 +4678,9 @@
     vinserti128          m6, m6, [cq+32*14], 1
     vinserti128          m7, m7, [cq+32*15], 1
     REPX  {pmulhrsw x, m9 }, m0, m1, m2, m3, m4, m5, m6, m7
-    REPX  {psllw    x, 2  }, m0, m1, m2, m3, m4, m5, m6, m7
+    REPX  {paddsw   x, x  }, m0, m1, m2, m3, m4, m5, m6, m7
     call m(inv_txfm_add_identity_identity_8x32).transpose8x8
-    pmulhrsw             m8, m10, m0
-    paddw                m0, m8
-    pmulhrsw             m8, m10, m1
-    paddw                m1, m8
-    pmulhrsw             m8, m10, m2
-    paddw                m2, m8
-    pmulhrsw             m8, m10, m3
-    paddw                m3, m8
-    pmulhrsw             m8, m10, m4
-    paddw                m4, m8
-    pmulhrsw             m8, m10, m5
-    paddw                m5, m8
-    pmulhrsw             m8, m10, m6
-    paddw                m6, m8
-    pmulhrsw             m8, m10, m7
-    paddw                m7, m8
+    REPX  {IDTX16 x, 8, 10}, 0, 1, 2, 3, 4, 5, 6, 7
     REPX  {pmulhrsw x, m11}, m0, m1, m2, m3, m4, m5, m6, m7
     WRITE_16X2            0,  1,  8,  0, strideq*0, strideq*1
     WRITE_16X2            2,  3,  0,  1, strideq*2, r3