shithub: dav1d

Download patch

ref: a20b5757c766999bf3078c6c186f93aefce1d59e
parent: 36d615d1204c126ad6a9fc560cde94dcd7bafb6d
author: Henrik Gramner <[email protected]>
date: Thu Oct 24 12:58:39 EDT 2019

x86: Fix overflows in inverse identity AVX2 transforms

--- a/src/x86/itx.asm
+++ b/src/x86/itx.asm
@@ -52,13 +52,15 @@
 pw_m3803_m6688: dw -3803, -6688
 pw_2896_m2896:  dw  2896, -2896
 
-pw_5:      times 2 dw 5
-pw_2048:   times 2 dw 2048
-pw_4096:   times 2 dw 4096
-pw_8192:   times 2 dw 8192
-pw_16384:  times 2 dw 16384
-pw_2896x8: times 2 dw 2896*8
-pw_5793x4: times 2 dw 5793*4
+pw_5:       times 2 dw 5
+pw_2048:    times 2 dw 2048     ; pmulhrsw by this: (x+8)>>4, i.e. round(x/16)
+pw_4096:    times 2 dw 4096     ; pmulhrsw by this: round(x/8)
+pw_8192:    times 2 dw 8192     ; pmulhrsw by this: round(x/4)
+pw_16384:   times 2 dw 16384    ; pmulhrsw by this: round(x/2)
+pw_1697x16: times 2 dw 1697*16  ; pmulhrsw: x*1697/2048 ~= x*2*(sqrt(2)-1); 2x + this ~= x*2*sqrt(2)
+pw_1697x8:  times 2 dw 1697*8   ; pmulhrsw: x*1697/4096 ~= x*(sqrt(2)-1); x + this ~= x*sqrt(2)
+pw_2896x8:  times 2 dw 2896*8   ; pmulhrsw: x*2896/4096 ~= x/sqrt(2)
+pw_5793x4:  times 2 dw 5793*4   ; pmulhrsw: x*5793/8192 ~= x/sqrt(2)
 
 pd_2048: dd 2048
 
@@ -389,9 +391,9 @@
 %ifidn %1_%2, dct_identity
     vpbroadcastd         m0, [o(pw_2896x8)]
     pmulhrsw             m0, [cq]
-    vpbroadcastd         m1, [o(pw_5793x4)]
-    paddw                m0, m0
-    pmulhrsw             m0, m1
+    vpbroadcastd         m1, [o(pw_1697x8)]
+    pmulhrsw             m1, m0
+    paddw                m0, m1
     punpcklwd            m0, m0
     punpckhdq            m1, m0, m0
     punpckldq            m0, m0
@@ -399,12 +401,12 @@
 %elifidn %1_%2, identity_dct
     mova                 m0, [cq+16*0]
     packusdw             m0, [cq+16*1]
-    vpbroadcastd         m2, [o(pw_5793x4)]
-    vpbroadcastd         m3, [o(pw_2896x8)]
+    vpbroadcastd         m1, [o(pw_1697x8)]
+    vpbroadcastd         m2, [o(pw_2896x8)]
     packusdw             m0, m0
-    paddw                m0, m0
+    pmulhrsw             m1, m0
+    paddw                m0, m1
     pmulhrsw             m0, m2
-    pmulhrsw             m0, m3
     mova                 m1, m0
     jmp m(iadst_4x4_internal).end
 %elif %3 >= 0
@@ -556,11 +558,11 @@
 cglobal iidentity_4x4_internal, 0, 5, 6, dst, stride, c, eob, tx2
     mova                 m0, [cq+16*0]
     mova                 m1, [cq+16*1]
-    vpbroadcastd         m2, [o(pw_5793x4)]
-    paddw                m0, m0
-    paddw                m1, m1
-    pmulhrsw             m0, m2
-    pmulhrsw             m1, m2
+    vpbroadcastd         m3, [o(pw_1697x8)]
+    pmulhrsw             m2, m3, m0
+    pmulhrsw             m3, m1
+    paddw                m0, m2
+    paddw                m1, m3
     punpckhwd            m2, m0, m1
     punpcklwd            m0, m1
     punpckhwd            m1, m0, m2
@@ -567,11 +569,11 @@
     punpcklwd            m0, m2
     jmp                tx2q
 .pass2:
-    vpbroadcastd         m2, [o(pw_5793x4)]
-    paddw                m0, m0
-    paddw                m1, m1
-    pmulhrsw             m0, m2
-    pmulhrsw             m1, m2
+    vpbroadcastd         m3, [o(pw_1697x8)]
+    pmulhrsw             m2, m3, m0
+    pmulhrsw             m3, m1
+    paddw                m0, m2
+    paddw                m1, m3
     jmp m(iadst_4x4_internal).end
 
 %macro WRITE_4X8 2 ; coefs[1-2]
@@ -619,12 +621,12 @@
     movd                xm1, [cq+16*2]
     punpcklwd           xm1, [cq+16*3]
     vpbroadcastd        xm2, [o(pw_2896x8)]
-    vpbroadcastd        xm3, [o(pw_5793x4)]
+    vpbroadcastd        xm3, [o(pw_1697x8)]
     vpbroadcastd        xm4, [o(pw_2048)]
     punpckldq           xm0, xm1
     pmulhrsw            xm0, xm2
-    paddw               xm0, xm0
-    pmulhrsw            xm0, xm3
+    pmulhrsw            xm3, xm0
+    paddw               xm0, xm3
     pmulhrsw            xm0, xm2
     pmulhrsw            xm0, xm4
     vpbroadcastq         m0, xm0
@@ -896,7 +898,7 @@
     vpermq               m2, [cq+32*0], q3120
     vpermq               m0, [cq+32*1], q3120
     vpbroadcastd         m3, [o(pw_2896x8)]
-    vpbroadcastd         m4, [o(pw_5793x4)]
+    vpbroadcastd         m4, [o(pw_1697x8)]
     punpcklwd            m1, m2, m0
     punpckhwd            m2, m0
     pmulhrsw             m1, m3
@@ -903,10 +905,10 @@
     pmulhrsw             m2, m3
     punpcklwd            m0, m1, m2
     punpckhwd            m1, m2
-    paddw                m0, m0
-    paddw                m1, m1
-    pmulhrsw             m0, m4
-    pmulhrsw             m1, m4
+    pmulhrsw             m2, m4, m0
+    pmulhrsw             m4, m1
+    paddw                m0, m2
+    paddw                m1, m4
     jmp                tx2q
 .pass2:
     vpbroadcastd         m4, [o(pw_4096)]
@@ -919,11 +921,12 @@
     vpbroadcastd         m0, [o(pw_2896x8)]
     pmulhrsw             m0, [cq]
     vpbroadcastd         m1, [o(pw_16384)]
-    vpbroadcastd         m2, [o(pw_5793x4)]
+    vpbroadcastd         m2, [o(pw_1697x16)]
     vpbroadcastd         m3, [o(pw_2048)]
     pmulhrsw             m0, m1
-    psllw                m0, 2
-    pmulhrsw             m0, m2
+    pmulhrsw             m2, m0
+    paddw                m0, m0
+    paddw                m0, m2
     pmulhrsw             m3, m0
     punpcklwd            m1, m3, m3
     punpckhwd            m3, m3
@@ -937,12 +940,12 @@
     punpcklwd           xm0, [cq+32*1]
     movd                xm1, [cq+32*2]
     punpcklwd           xm1, [cq+32*3]
-    vpbroadcastd        xm2, [o(pw_5793x4)]
+    vpbroadcastd        xm2, [o(pw_1697x8)]
     vpbroadcastd        xm3, [o(pw_16384)]
     vpbroadcastd        xm4, [o(pw_2896x8)]
     punpckldq           xm0, xm1
-    paddw               xm0, xm0
-    pmulhrsw            xm0, xm2
+    pmulhrsw            xm2, xm0
+    paddw               xm0, xm2
     pmulhrsw            xm0, xm3
     psrlw               xm3, 3 ; pw_2048
     pmulhrsw            xm0, xm4
@@ -1281,13 +1284,19 @@
     mova                 m2, [cq+32*1]
     mova                 m4, [cq+32*2]
     mova                 m0, [cq+32*3]
-    vpbroadcastd         m5, [o(pw_5793x4)]
+    vpbroadcastd         m5, [o(pw_1697x8)]
     punpcklwd            m1, m3, m2
     punpckhwd            m3, m2
     punpcklwd            m2, m4, m0
     punpckhwd            m4, m0
-    REPX   {paddw    x, x }, m1, m2, m3, m4
-    REPX   {pmulhrsw x, m5}, m1, m2, m3, m4
+    pmulhrsw             m0, m5, m1
+    pmulhrsw             m6, m5, m2
+    pmulhrsw             m7, m5, m3
+    pmulhrsw             m5, m4
+    paddw                m1, m0
+    paddw                m2, m6
+    paddw                m3, m7
+    paddw                m4, m5
     vpbroadcastd         m5, [o(pw_16384)]
     punpckldq            m0, m1, m2
     punpckhdq            m1, m2
@@ -1296,10 +1305,17 @@
     REPX   {pmulhrsw x, m5}, m0, m1, m2, m3
     jmp                tx2q
 .pass2:
-    vpbroadcastd         m4, [o(pw_5793x4)]
+    vpbroadcastd         m8, [o(pw_1697x16)]
     vpbroadcastd         m5, [o(pw_2048)]
-    REPX   {psllw    x, 2 }, m0, m1, m2, m3
-    REPX   {pmulhrsw x, m4}, m0, m1, m2, m3
+    pmulhrsw             m4, m8, m0
+    pmulhrsw             m6, m8, m1
+    pmulhrsw             m7, m8, m2
+    pmulhrsw             m8, m3
+    REPX       {paddw x, x}, m0, m1, m2, m3
+    paddw                m0, m4
+    paddw                m1, m6
+    paddw                m2, m7
+    paddw                m3, m8
     jmp m(iadst_4x16_internal).end2
 
 %macro WRITE_8X4 4-7 strideq*1, strideq*2, r3, ; coefs[1-2], tmp[1-2], off[1-3]
@@ -1333,11 +1349,11 @@
 %ifidn %1_%2, dct_identity
     vpbroadcastd        xm0, [o(pw_2896x8)]
     pmulhrsw            xm1, xm0, [cq]
-    vpbroadcastd        xm2, [o(pw_5793x4)]
+    vpbroadcastd        xm2, [o(pw_1697x8)]
     vpbroadcastd        xm3, [o(pw_2048)]
     pmulhrsw            xm1, xm0
-    paddw               xm1, xm1
-    pmulhrsw            xm1, xm2
+    pmulhrsw            xm2, xm1
+    paddw               xm1, xm2
     pmulhrsw            xm1, xm3
     punpcklwd           xm1, xm1
     punpckldq           xm0, xm1, xm1
@@ -1508,11 +1524,11 @@
     paddw                m1, m1
     jmp                tx2q
 .pass2:
-    vpbroadcastd         m2, [o(pw_5793x4)]
-    paddw                m0, m0
-    paddw                m1, m1
-    pmulhrsw             m0, m2
-    pmulhrsw             m1, m2
+    vpbroadcastd         m3, [o(pw_1697x8)]
+    pmulhrsw             m2, m3, m0
+    pmulhrsw             m3, m1
+    paddw                m0, m2
+    paddw                m1, m3
     jmp m(iadst_8x4_internal).end
 
 %macro INV_TXFM_8X8_FN 2-3 -1 ; type1, type2, fast_thresh
@@ -1773,14 +1789,15 @@
     vpbroadcastd         m0, [o(pw_2896x8)]
     pmulhrsw             m7, m0, [cq]
     vpbroadcastd         m1, [o(pw_16384)]
-    vpbroadcastd         m2, [o(pw_5793x4)]
+    vpbroadcastd         m2, [o(pw_1697x16)]
     pxor                 m3, m3
     mova               [cq], m3
     pmulhrsw             m7, m0
     pmulhrsw             m7, m1
     psrlw                m1, 3 ; pw_2048
-    psllw                m7, 2
-    pmulhrsw             m7, m2
+    pmulhrsw             m2, m7
+    paddw                m7, m7
+    paddw                m7, m2
     pmulhrsw             m7, m1
     punpcklwd            m5, m7, m7
     punpckhwd            m7, m7
@@ -2101,6 +2118,16 @@
 INV_TXFM_8X16_FN identity, flipadst
 INV_TXFM_8X16_FN identity, identity
 
+%macro IDTX16 3-4 ; src/dst, tmp, pw_1697x16, [pw_16384] -- dst ~= src*2*sqrt(2) (src*sqrt(2) with arg 4)
+    pmulhrsw            m%2, m%3, m%1 ; tmp = src*1697/2048 ~= src*2*(sqrt(2)-1); m%3 is preserved
+%if %0 == 4 ; if we're going to downshift by 1, doing so here eliminates the paddw
+    pmulhrsw            m%2, m%4      ; tmp = round(tmp/2), given m%4 holds pw_16384
+%else
+    paddw               m%1, m%1      ; src *= 2 (only in the unscaled variant)
+%endif
+    paddw               m%1, m%2      ; src += tmp
+%endmacro
+
 cglobal iidentity_8x16_internal, 0, 5, 13, dst, stride, c, eob, tx2
     mova                xm3,     [cq+16*0]
     mova                xm2,     [cq+16*2]
@@ -2139,10 +2166,9 @@
     punpckhdq            m7, m8
     jmp                tx2q
 .pass2:
-    vpbroadcastd         m8, [o(pw_5793x4)]
-    REPX {psllw    x, 2       }, m0, m1, m2, m3, m4, m5, m6, m7
+    vpbroadcastd         m8, [o(pw_1697x16)]
     REPX {vpermq   x, x, q3120}, m0, m1, m2, m3, m4, m5, m6, m7
-    REPX {pmulhrsw x, m8      }, m0, m1, m2, m3, m4, m5, m6, m7
+    REPX {IDTX16   x, 9, 8}, 0, 1, 2, 3, 4, 5, 6, 7
     jmp m(idct_8x16_internal).end
 
 %macro WRITE_16X2 6 ; coefs[1-2], tmp[1-2], offset[1-2]
@@ -2171,11 +2197,11 @@
     vpbroadcastd        xm3, [o(pw_2896x8)]
     pmulhrsw            xm3, [cq]
     vpbroadcastd        xm0, [o(pw_16384)]
-    vpbroadcastd        xm1, [o(pw_5793x4)]
+    vpbroadcastd        xm1, [o(pw_1697x8)]
     pmulhrsw            xm3, xm0
     psrlw               xm0, 3 ; pw_2048
-    paddw               xm3, xm3
-    pmulhrsw            xm3, xm1
+    pmulhrsw            xm1, xm3
+    paddw               xm3, xm1
     pmulhrsw            xm3, xm0
     punpcklwd           xm3, xm3
     punpckldq           xm1, xm3, xm3
@@ -2194,15 +2220,15 @@
     mova                xm3,     [cq+16*3]
     vinserti128          m1, m1, [cq+16*6], 1
     vinserti128          m3, m3, [cq+16*7], 1
-    vpbroadcastd         m4, [o(pw_5793x4)]
+    vpbroadcastd         m4, [o(pw_1697x16)]
     vpbroadcastd         m5, [o(pw_16384)]
     packusdw             m0, m2
     packusdw             m1, m3
     packusdw             m0, m1
     vpbroadcastd         m1, [o(pw_2896x8)]
-    psllw                m0, 2
-    pmulhrsw             m0, m4
-    pmulhrsw             m0, m5
+    pmulhrsw             m4, m0
+    pmulhrsw             m4, m5
+    paddw                m0, m4
     psrlw                m5, 3 ; pw_2048
     pmulhrsw             m0, m1
     pmulhrsw             m0, m5
@@ -2462,28 +2488,40 @@
     mova                xm1,     [cq+16*3]
     vinserti128          m0, m0, [cq+16*6], 1
     vinserti128          m1, m1, [cq+16*7], 1
-    vpbroadcastd         m5, [o(pw_5793x4)]
+    vpbroadcastd         m7, [o(pw_1697x16)]
+    vpbroadcastd         m8, [o(pw_16384)]
     punpcklwd            m3, m2, m4
     punpckhwd            m2, m4
     punpcklwd            m4, m0, m1
     punpckhwd            m0, m1
-    REPX       {psllw x, 2}, m3, m2, m4, m0
     punpcklwd            m1, m3, m2
     punpckhwd            m3, m2
     punpcklwd            m2, m4, m0
     punpckhwd            m4, m0
-    REPX   {pmulhrsw x, m5}, m1, m3, m2, m4
-    vpbroadcastd         m5, [o(pw_16384)]
+    pmulhrsw             m0, m7, m1
+    pmulhrsw             m5, m7, m2
+    pmulhrsw             m6, m7, m3
+    pmulhrsw             m7, m4
+    REPX   {pmulhrsw x, m8}, m0, m5, m6, m7
+    paddw                m1, m0
+    paddw                m2, m5
+    paddw                m3, m6
+    paddw                m4, m7
     punpcklqdq           m0, m1, m2
     punpckhqdq           m1, m2
     punpcklqdq           m2, m3, m4
     punpckhqdq           m3, m4
-    REPX   {pmulhrsw x, m5}, m0, m1, m2, m3
     jmp                tx2q
 .pass2:
-    vpbroadcastd         m4, [o(pw_5793x4)]
-    REPX   {paddw    x, x }, m0, m1, m2, m3
-    REPX   {pmulhrsw x, m4}, m0, m1, m2, m3
+    vpbroadcastd         m7, [o(pw_1697x8)]
+    pmulhrsw             m4, m7, m0
+    pmulhrsw             m5, m7, m1
+    pmulhrsw             m6, m7, m2
+    pmulhrsw             m7, m3
+    paddw                m0, m4
+    paddw                m1, m5
+    paddw                m2, m6
+    paddw                m3, m7
     jmp m(iadst_16x4_internal).end
 
 %macro INV_TXFM_16X8_FN 2-3 -1 ; type1, type2, fast_thresh
@@ -2532,7 +2570,7 @@
     mova                 m3, [cq+32*6]
     packusdw             m3, [cq+32*7]
     vpbroadcastd         m4, [o(pw_2896x8)]
-    vpbroadcastd         m5, [o(pw_5793x4)]
+    vpbroadcastd         m5, [o(pw_1697x16)]
     packusdw             m0, m2
     packusdw             m1, m3
     vpbroadcastd         m2, [o(pw_16384)]
@@ -2541,9 +2579,9 @@
     vpermq               m0, m0, q1100
     punpcklwd            m0, m1
     pmulhrsw             m0, m4
-    psllw                m0, 2
-    pmulhrsw             m0, m5
-    pmulhrsw             m0, m2
+    pmulhrsw             m5, m0
+    pmulhrsw             m5, m2
+    paddw                m0, m5
     psrlw                m2, 3 ; pw_2048
     pmulhrsw             m0, m4
     pmulhrsw             m0, m2
@@ -2816,8 +2854,8 @@
     mova                xm1,     [cq-16*1]
     vinserti128          m0, m0, [cq+16*6], 1
     vinserti128          m1, m1, [cq+16*7], 1
-    vpbroadcastd         m9, [o(pw_5793x4)]
-    vpbroadcastd        m10, [o(pw_16384)]
+    vpbroadcastd        m10, [o(pw_1697x16)]
+    vpbroadcastd        m11, [o(pw_16384)]
     REPX   {pmulhrsw x, m3}, m7, m2, m6, m4, m8, m5, m0, m1
     punpcklwd            m3, m7, m2
     punpckhwd            m7, m2
@@ -2827,7 +2865,6 @@
     punpckhwd            m8, m5
     punpcklwd            m5, m0, m1
     punpckhwd            m0, m1
-    REPX       {psllw x, 2}, m3, m7, m2, m6, m4, m8, m5, m0
     punpckldq            m1, m3, m2
     punpckhdq            m3, m2
     punpckldq            m2, m4, m5
@@ -2836,7 +2873,7 @@
     punpckhdq            m7, m6
     punpckldq            m6, m8, m0
     punpckhdq            m8, m0
-    REPX   {pmulhrsw x, m9}, m1, m3, m2, m4, m5, m7, m6, m8
+    REPX {IDTX16 x, 0, 10, 11}, 1, 3, 2, 4, 5, 7, 6, 8
     punpcklqdq           m0, m1, m2
     punpckhqdq           m1, m2
     punpcklqdq           m2, m3, m4
@@ -2845,7 +2882,6 @@
     punpckhqdq           m5, m6
     punpcklqdq           m6, m7, m8
     punpckhqdq           m7, m8
-    REPX  {pmulhrsw x, m10}, m0, m1, m2, m3, m4, m5, m6, m7
     jmp                tx2q
 .pass2:
     vpbroadcastd         m8, [o(pw_4096)]
@@ -2916,14 +2952,15 @@
     vinserti128          m2, m2, [cq+32*14], 1
     vinserti128          m4, m4, [cq+32*15], 1
     punpcklwd            m1, m3
-    vpbroadcastd         m3, [o(pw_5793x4)]
+    vpbroadcastd         m3, [o(pw_1697x16)]
     punpcklwd            m2, m4
     vpbroadcastd         m4, [o(pw_8192)]
     punpckldq            m1, m2
     vpbroadcastd         m2, [o(pw_2896x8)]
     punpcklqdq           m0, m1
-    psllw                m0, 2
-    pmulhrsw             m0, m3
+    pmulhrsw             m3, m0
+    paddw                m0, m0
+    paddw                m0, m3
     pmulhrsw             m0, m4
     psrlw                m4, 2 ; pw_2048
     pmulhrsw             m0, m2
@@ -3352,47 +3389,47 @@
 INV_TXFM_16X16_FN identity, identity
 
 cglobal iidentity_16x16_internal, 0, 5, 16, 32*3, dst, stride, c, eob, tx2
-    mova                xm0,      [cq+16*0]
-    mova               xm15,      [cq+16*1]
-    mova                xm1,      [cq+16*2]
-    mova                xm8,      [cq+16*3]
-    mova                xm2,      [cq+16*4]
-    mova                xm9,      [cq+16*5]
-    mova                xm3,      [cq+16*6]
-    mova               xm10,      [cq+16*7]
+    vpbroadcastd         m7, [o(pw_1697x16)]
+    mova                xm0, [cq+16* 0]
+    vinserti128          m0, [cq+16*16], 1
+    mova               xm15, [cq+16* 1]
+    vinserti128         m15, [cq+16*17], 1
+    mova                xm1, [cq+16* 2]
+    vinserti128          m1, [cq+16*18], 1
+    mova                xm8, [cq+16* 3]
+    vinserti128          m8, [cq+16*19], 1
+    mova                xm2, [cq+16* 4]
+    vinserti128          m2, [cq+16*20], 1
+    mova                xm9, [cq+16* 5]
+    vinserti128          m9, [cq+16*21], 1
+    mova                xm3, [cq+16* 6]
+    vinserti128          m3, [cq+16*22], 1
+    mova               xm10, [cq+16* 7]
     add                  cq, 16*16
-    vinserti128          m0, m0,  [cq+16*0], 1
-    vinserti128         m15, m15, [cq+16*1], 1
-    mova                xm4,      [cq-16*8]
-    mova               xm11,      [cq-16*7]
-    vinserti128          m1, m1,  [cq+16*2], 1
-    vinserti128          m8, m8,  [cq+16*3], 1
-    mova                xm5,      [cq-16*6]
-    mova               xm12,      [cq-16*5]
-    vinserti128          m2, m2,  [cq+16*4], 1
-    vinserti128          m9, m9,  [cq+16*5], 1
-    mova                xm6,      [cq-16*4]
-    mova               xm13,      [cq-16*3]
-    vinserti128          m3, m3,  [cq+16*6], 1
-    vinserti128         m10, m10, [cq+16*7], 1
-    mova                xm7,      [cq-16*2]
-    mova               xm14,      [cq-16*1]
-    vinserti128          m4, m4,  [cq+16*8], 1
-    vinserti128         m11, m11, [cq+16*9], 1
-    vinserti128          m5, m5,  [cq+16*10], 1
-    vinserti128         m12, m12, [cq+16*11], 1
-    vinserti128          m6, m6,  [cq+16*12], 1
-    vinserti128         m13, m13, [cq+16*13], 1
-    vinserti128          m7, m7,  [cq+16*14], 1
-    vinserti128         m14, m14, [cq+16*15], 1
-    REPX   {psllw    x, 2 }, m0,  m1,  m2,  m3,  m4,  m5,  m6,  m7, \
-                             m8,  m9,  m10, m11, m12, m13, m14, m15
-    mova              [rsp], m0
-    vpbroadcastd         m0, [o(pw_5793x4)]
-    REPX   {pmulhrsw x, m0},      m1,  m2,  m3,  m4,  m5,  m6,  m7, \
-                             m8,  m9,  m10, m11, m12, m13, m14, m15
-    pmulhrsw             m0, [rsp]
+    vinserti128         m10, [cq+16* 7], 1
+    mova                xm4, [cq-16* 8]
+    vinserti128          m4, [cq+16* 8], 1
+    mova               xm11, [cq-16* 7]
+    vinserti128         m11, [cq+16* 9], 1
+    mova                xm5, [cq-16* 6]
+    vinserti128          m5, [cq+16*10], 1
+    mova               xm12, [cq-16* 5]
+    vinserti128         m12, [cq+16*11], 1
+    mova               xm13, [cq-16* 3]
+    vinserti128         m13, [cq+16*13], 1
+    mova               xm14, [cq-16* 1]
+    vinserti128         m14, [cq+16*15], 1
+    REPX   {IDTX16 x, 6, 7},  0, 15,  1,  8,  2,  9,  3, \
+                             10,  4, 11,  5, 12, 13, 14
+    mova                xm6, [cq-16* 4]
+    vinserti128          m6, [cq+16*12], 1
     mova              [rsp], m1
+    IDTX16                6, 1, 7
+    mova                xm1, [cq-16* 2]
+    vinserti128          m1, [cq+16*14], 1
+    pmulhrsw             m7, m1
+    paddw                m1, m1
+    paddw                m7, m1
     vpbroadcastd         m1, [o(pw_8192)]
     REPX   {pmulhrsw x, m1}, m0,       m2,  m3,  m4,  m5,  m6,  m7, \
                              m8,  m9,  m10, m11, m12, m13, m14, m15
@@ -3401,14 +3438,17 @@
     jmp m(idct_16x16_internal).pass1_end3
 ALIGN function_align
 .pass2:
-    vpbroadcastd        m15, [o(pw_5793x4)]
-    REPX  {psllw    x, 2  }, m0, m1, m2, m3, m4, m5, m6, m7
-    REPX  {pmulhrsw x, m15}, m0, m1, m2, m3, m4, m5, m6, m7
+    vpbroadcastd        m15, [o(pw_1697x16)]
+    mova         [rsp+32*1], m0
+    REPX  {IDTX16 x, 0, 15},  1,  2,  3,  4,  5,  6,  7, \
+                              8,  9, 10, 11, 12, 13, 14
+    mova                 m0, [rsp+32*1]
     mova         [rsp+32*1], m1
+    IDTX16                0, 1, 15
     mova                 m1, [rsp+32*0]
-    REPX  {psllw    x, 2  }, m8, m9, m10, m11, m12, m13, m14, m1
-    REPX  {pmulhrsw x, m15}, m8, m9, m10, m11, m12, m13, m14
     pmulhrsw            m15, m1
+    paddw                m1, m1
+    paddw               m15, m1
     jmp m(idct_16x16_internal).end
 
 %define o_base iadst4_dconly2a + 128
@@ -4606,7 +4646,7 @@
 %undef cmp
     lea                 rax, [o_base]
     vpbroadcastd         m9, [o(pw_2896x8)]
-    vpbroadcastd        m10, [o(pw_5793x4)]
+    vpbroadcastd        m10, [o(pw_1697x8)]
     vpbroadcastd        m11, [o(pw_2048)]
     cmp                eobd, 35  ; if (eob > 35)
     setg                r4b      ;   iteration_count++
@@ -4634,9 +4674,24 @@
     vinserti128          m6, m6, [cq+32*14], 1
     vinserti128          m7, m7, [cq+32*15], 1
     REPX  {pmulhrsw x, m9 }, m0, m1, m2, m3, m4, m5, m6, m7
-    REPX  {psllw    x, 3  }, m0, m1, m2, m3, m4, m5, m6, m7
+    REPX  {psllw    x, 2  }, m0, m1, m2, m3, m4, m5, m6, m7
     call m(inv_txfm_add_identity_identity_8x32).transpose8x8
-    REPX  {pmulhrsw x, m10}, m0, m1, m2, m3, m4, m5, m6, m7
+    pmulhrsw             m8, m10, m0
+    paddw                m0, m8
+    pmulhrsw             m8, m10, m1
+    paddw                m1, m8
+    pmulhrsw             m8, m10, m2
+    paddw                m2, m8
+    pmulhrsw             m8, m10, m3
+    paddw                m3, m8
+    pmulhrsw             m8, m10, m4
+    paddw                m4, m8
+    pmulhrsw             m8, m10, m5
+    paddw                m5, m8
+    pmulhrsw             m8, m10, m6
+    paddw                m6, m8
+    pmulhrsw             m8, m10, m7
+    paddw                m7, m8
     REPX  {pmulhrsw x, m11}, m0, m1, m2, m3, m4, m5, m6, m7
     WRITE_16X2            0,  1,  8,  0, strideq*0, strideq*1
     WRITE_16X2            2,  3,  0,  1, strideq*2, r3