shithub: dav1d

Download patch

ref: a47212259e2c6ca44b2ec502099ff33247684b52
parent: c37b5ee3ce9a489cbc3db4d30fa552722da06455
author: Henrik Gramner <[email protected]>
date: Tue Jan 7 19:44:35 EST 2020

Add misc. inverse transform C optimizations

--- a/src/itx_1d.c
+++ b/src/itx_1d.c
@@ -62,42 +62,63 @@
  * wrap around.
  */
 
-void dav1d_inv_dct4_1d_c(const int32_t *const in, const ptrdiff_t in_s,
-                         int32_t *const out, const ptrdiff_t out_s,
-                         const int max)
+static NOINLINE void
+inv_dct4_1d_internal_c(int32_t *const c, const ptrdiff_t stride,
+                       const int min, const int max, const int tx64)
 {
-    const int min = -max - 1;
-    const int in0 = in[0 * in_s], in1 = in[1 * in_s];
-    const int in2 = in[2 * in_s], in3 = in[3 * in_s];
+    assert(stride > 0);
+    const int in0 = c[0 * stride], in1 = c[1 * stride];
 
-    int t0 = ((in0 + in2) * 181 + 128) >> 8;
-    int t1 = ((in0 - in2) * 181 + 128) >> 8;
-    int t2 = ((in1 *  1567         - in3 * (3784 - 4096) + 2048) >> 12) - in3;
-    int t3 = ((in1 * (3784 - 4096) + in3 *  1567         + 2048) >> 12) + in1;
+    int t0, t1, t2, t3;
+    if (tx64) {
+        t0 = t1 = (in0 * 181 + 128) >> 8;
+        t2 = (in1 * 1567 + 2048) >> 12;
+        t3 = (in1 * 3784 + 2048) >> 12;
+    } else {
+        const int in2 = c[2 * stride], in3 = c[3 * stride];
 
-    out[0 * out_s] = CLIP(t0 + t3);
-    out[1 * out_s] = CLIP(t1 + t2);
-    out[2 * out_s] = CLIP(t1 - t2);
-    out[3 * out_s] = CLIP(t0 - t3);
+        t0 = ((in0 + in2) * 181 + 128) >> 8;
+        t1 = ((in0 - in2) * 181 + 128) >> 8;
+        t2 = ((in1 *  1567         - in3 * (3784 - 4096) + 2048) >> 12) - in3;
+        t3 = ((in1 * (3784 - 4096) + in3 *  1567         + 2048) >> 12) + in1;
+    }
+
+    c[0 * stride] = CLIP(t0 + t3);
+    c[1 * stride] = CLIP(t1 + t2);
+    c[2 * stride] = CLIP(t1 - t2);
+    c[3 * stride] = CLIP(t0 - t3);
 }
 
-void dav1d_inv_dct8_1d_c(const int32_t *const in, const ptrdiff_t in_s,
-                         int32_t *const out, const ptrdiff_t out_s,
-                         const int max)
+void dav1d_inv_dct4_1d_c(int32_t *const c, const ptrdiff_t stride,
+                         const int min, const int max)
 {
-    const int min = -max - 1;
-    int32_t tmp[4];
+    inv_dct4_1d_internal_c(c, stride, min, max, 0);
+}
 
-    dav1d_inv_dct4_1d_c(in, in_s * 2, tmp, 1, max);
+static NOINLINE void
+inv_dct8_1d_internal_c(int32_t *const c, const ptrdiff_t stride,
+                       const int min, const int max, const int tx64)
+{
+    assert(stride > 0);
+    inv_dct4_1d_internal_c(c, stride << 1, min, max, tx64);
 
-    const int in1 = in[1 * in_s], in3 = in[3 * in_s];
-    const int in5 = in[5 * in_s], in7 = in[7 * in_s];
+    const int in1 = c[1 * stride], in3 = c[3 * stride];
 
-    int t4a = ((in1 *   799         - in7 * (4017 - 4096) + 2048) >> 12) - in7;
-    int t5a =  (in5 *  1703         - in3 *  1138         + 1024) >> 11;
-    int t6a =  (in5 *  1138         + in3 *  1703         + 1024) >> 11;
-    int t7a = ((in1 * (4017 - 4096) + in7 *  799          + 2048) >> 12) + in1;
+    int t4a, t5a, t6a, t7a;
+    if (tx64) {
+        t4a = (in1 *   799 + 2048) >> 12;
+        t5a = (in3 * -2276 + 2048) >> 12;
+        t6a = (in3 *  3406 + 2048) >> 12;
+        t7a = (in1 *  4017 + 2048) >> 12;
+    } else {
+        const int in5 = c[5 * stride], in7 = c[7 * stride];
 
+        t4a = ((in1 *   799         - in7 * (4017 - 4096) + 2048) >> 12) - in7;
+        t5a =  (in5 *  1703         - in3 *  1138         + 1024) >> 11;
+        t6a =  (in5 *  1138         + in3 *  1703         + 1024) >> 11;
+        t7a = ((in1 * (4017 - 4096) + in7 *  799          + 2048) >> 12) + in1;
+    }
+
     int t4  = CLIP(t4a + t5a);
         t5a = CLIP(t4a - t5a);
     int t7  = CLIP(t7a + t6a);
@@ -106,39 +127,61 @@
     int t5  = ((t6a - t5a) * 181 + 128) >> 8;
     int t6  = ((t6a + t5a) * 181 + 128) >> 8;
 
-    out[0 * out_s] = CLIP(tmp[0] + t7);
-    out[1 * out_s] = CLIP(tmp[1] + t6);
-    out[2 * out_s] = CLIP(tmp[2] + t5);
-    out[3 * out_s] = CLIP(tmp[3] + t4);
-    out[4 * out_s] = CLIP(tmp[3] - t4);
-    out[5 * out_s] = CLIP(tmp[2] - t5);
-    out[6 * out_s] = CLIP(tmp[1] - t6);
-    out[7 * out_s] = CLIP(tmp[0] - t7);
+    const int t0 = c[0 * stride];
+    const int t1 = c[2 * stride];
+    const int t2 = c[4 * stride];
+    const int t3 = c[6 * stride];
+
+    c[0 * stride] = CLIP(t0 + t7);
+    c[1 * stride] = CLIP(t1 + t6);
+    c[2 * stride] = CLIP(t2 + t5);
+    c[3 * stride] = CLIP(t3 + t4);
+    c[4 * stride] = CLIP(t3 - t4);
+    c[5 * stride] = CLIP(t2 - t5);
+    c[6 * stride] = CLIP(t1 - t6);
+    c[7 * stride] = CLIP(t0 - t7);
 }
 
-void dav1d_inv_dct16_1d_c(const int32_t *const in, const ptrdiff_t in_s,
-                          int32_t *const out, const ptrdiff_t out_s,
-                          const int max)
+void dav1d_inv_dct8_1d_c(int32_t *const c, const ptrdiff_t stride,
+                         const int min, const int max)
 {
-    const int min = -max - 1;
-    int32_t tmp[8];
+    inv_dct8_1d_internal_c(c, stride, min, max, 0);
+}
 
-    dav1d_inv_dct8_1d_c(in, in_s * 2, tmp, 1, max);
+static NOINLINE void
+inv_dct16_1d_internal_c(int32_t *const c, const ptrdiff_t stride,
+                        const int min, const int max, int tx64)
+{
+    assert(stride > 0);
+    inv_dct8_1d_internal_c(c, stride << 1, min, max, tx64);
 
-    const int in1  = in[ 1 * in_s], in3  = in[ 3 * in_s];
-    const int in5  = in[ 5 * in_s], in7  = in[ 7 * in_s];
-    const int in9  = in[ 9 * in_s], in11 = in[11 * in_s];
-    const int in13 = in[13 * in_s], in15 = in[15 * in_s];
+    const int in1 = c[1 * stride], in3 = c[3 * stride];
+    const int in5 = c[5 * stride], in7 = c[7 * stride];
 
-    int t8a  = ((in1  *   401         - in15 * (4076 - 4096) + 2048) >> 12) - in15;
-    int t15a = ((in1  * (4076 - 4096) + in15 *   401         + 2048) >> 12) + in1;
-    int t9a  =  (in9  *  1583         - in7  *  1299         + 1024) >> 11;
-    int t14a =  (in9  *  1299         + in7  *  1583         + 1024) >> 11;
-    int t10a = ((in5  *  1931         - in11 * (3612 - 4096) + 2048) >> 12) - in11;
-    int t13a = ((in5  * (3612 - 4096) + in11 *  1931         + 2048) >> 12) + in5;
-    int t11a = ((in13 * (3920 - 4096) - in3  *  1189         + 2048) >> 12) + in13;
-    int t12a = ((in13 *  1189         + in3  * (3920 - 4096) + 2048) >> 12) + in3;
+    int t8a, t9a, t10a, t11a, t12a, t13a, t14a, t15a;
+    if (tx64) {
+        t8a  = (in1 *   401 + 2048) >> 12;
+        t9a  = (in7 * -2598 + 2048) >> 12;
+        t10a = (in5 *  1931 + 2048) >> 12;
+        t11a = (in3 * -1189 + 2048) >> 12;
+        t12a = (in3 *  3920 + 2048) >> 12;
+        t13a = (in5 *  3612 + 2048) >> 12;
+        t14a = (in7 *  3166 + 2048) >> 12;
+        t15a = (in1 *  4076 + 2048) >> 12;
+    } else {
+        const int in9  = c[ 9 * stride], in11 = c[11 * stride];
+        const int in13 = c[13 * stride], in15 = c[15 * stride];
 
+        t8a  = ((in1  *   401         - in15 * (4076 - 4096) + 2048) >> 12) - in15;
+        t9a  =  (in9  *  1583         - in7  *  1299         + 1024) >> 11;
+        t10a = ((in5  *  1931         - in11 * (3612 - 4096) + 2048) >> 12) - in11;
+        t11a = ((in13 * (3920 - 4096) - in3  *  1189         + 2048) >> 12) + in13;
+        t12a = ((in13 *  1189         + in3  * (3920 - 4096) + 2048) >> 12) + in3;
+        t13a = ((in5  * (3612 - 4096) + in11 *  1931         + 2048) >> 12) + in5;
+        t14a =  (in9  *  1299         + in7  *  1583         + 1024) >> 11;
+        t15a = ((in1  * (4076 - 4096) + in15 *   401         + 2048) >> 12) + in1;
+    }
+
     int t8  = CLIP(t8a  + t9a);
     int t9  = CLIP(t8a  - t9a);
     int t10 = CLIP(t11a - t10a);
@@ -167,59 +210,94 @@
     t11  = ((t12a - t11a) * 181 + 128) >> 8;
     t12  = ((t12a + t11a) * 181 + 128) >> 8;
 
-    out[ 0 * out_s] = CLIP(tmp[0] + t15a);
-    out[ 1 * out_s] = CLIP(tmp[1] + t14);
-    out[ 2 * out_s] = CLIP(tmp[2] + t13a);
-    out[ 3 * out_s] = CLIP(tmp[3] + t12);
-    out[ 4 * out_s] = CLIP(tmp[4] + t11);
-    out[ 5 * out_s] = CLIP(tmp[5] + t10a);
-    out[ 6 * out_s] = CLIP(tmp[6] + t9);
-    out[ 7 * out_s] = CLIP(tmp[7] + t8a);
-    out[ 8 * out_s] = CLIP(tmp[7] - t8a);
-    out[ 9 * out_s] = CLIP(tmp[6] - t9);
-    out[10 * out_s] = CLIP(tmp[5] - t10a);
-    out[11 * out_s] = CLIP(tmp[4] - t11);
-    out[12 * out_s] = CLIP(tmp[3] - t12);
-    out[13 * out_s] = CLIP(tmp[2] - t13a);
-    out[14 * out_s] = CLIP(tmp[1] - t14);
-    out[15 * out_s] = CLIP(tmp[0] - t15a);
+    const int t0 = c[ 0 * stride];
+    const int t1 = c[ 2 * stride];
+    const int t2 = c[ 4 * stride];
+    const int t3 = c[ 6 * stride];
+    const int t4 = c[ 8 * stride];
+    const int t5 = c[10 * stride];
+    const int t6 = c[12 * stride];
+    const int t7 = c[14 * stride];
+
+    c[ 0 * stride] = CLIP(t0 + t15a);
+    c[ 1 * stride] = CLIP(t1 + t14);
+    c[ 2 * stride] = CLIP(t2 + t13a);
+    c[ 3 * stride] = CLIP(t3 + t12);
+    c[ 4 * stride] = CLIP(t4 + t11);
+    c[ 5 * stride] = CLIP(t5 + t10a);
+    c[ 6 * stride] = CLIP(t6 + t9);
+    c[ 7 * stride] = CLIP(t7 + t8a);
+    c[ 8 * stride] = CLIP(t7 - t8a);
+    c[ 9 * stride] = CLIP(t6 - t9);
+    c[10 * stride] = CLIP(t5 - t10a);
+    c[11 * stride] = CLIP(t4 - t11);
+    c[12 * stride] = CLIP(t3 - t12);
+    c[13 * stride] = CLIP(t2 - t13a);
+    c[14 * stride] = CLIP(t1 - t14);
+    c[15 * stride] = CLIP(t0 - t15a);
 }
 
-void dav1d_inv_dct32_1d_c(const int32_t *const in, const ptrdiff_t in_s,
-                          int32_t *const out, const ptrdiff_t out_s,
-                          const int max)
+void dav1d_inv_dct16_1d_c(int32_t *const c, const ptrdiff_t stride,
+                          const int min, const int max)
 {
-    const int min = -max - 1;
-    int32_t tmp[16];
+    inv_dct16_1d_internal_c(c, stride, min, max, 0);
+}
 
-    dav1d_inv_dct16_1d_c(in, in_s * 2, tmp, 1, max);
+static NOINLINE void
+inv_dct32_1d_internal_c(int32_t *const c, const ptrdiff_t stride,
+                        const int min, const int max, const int tx64)
+{
+    assert(stride > 0);
+    inv_dct16_1d_internal_c(c, stride << 1, min, max, tx64);
 
-    const int in1  = in[ 1 * in_s], in3  = in[ 3 * in_s];
-    const int in5  = in[ 5 * in_s], in7  = in[ 7 * in_s];
-    const int in9  = in[ 9 * in_s], in11 = in[11 * in_s];
-    const int in13 = in[13 * in_s], in15 = in[15 * in_s];
-    const int in17 = in[17 * in_s], in19 = in[19 * in_s];
-    const int in21 = in[21 * in_s], in23 = in[23 * in_s];
-    const int in25 = in[25 * in_s], in27 = in[27 * in_s];
-    const int in29 = in[29 * in_s], in31 = in[31 * in_s];
+    const int in1  = c[ 1 * stride], in3  = c[ 3 * stride];
+    const int in5  = c[ 5 * stride], in7  = c[ 7 * stride];
+    const int in9  = c[ 9 * stride], in11 = c[11 * stride];
+    const int in13 = c[13 * stride], in15 = c[15 * stride];
 
-    int t16a = ((in1  *   201         - in31 * (4091 - 4096) + 2048) >> 12) - in31;
-    int t31a = ((in1  * (4091 - 4096) + in31 *   201         + 2048) >> 12) + in1;
-    int t17a = ((in17 * (3035 - 4096) - in15 *  2751         + 2048) >> 12) + in17;
-    int t30a = ((in17 *  2751         + in15 * (3035 - 4096) + 2048) >> 12) + in15;
-    int t18a = ((in9  *  1751         - in23 * (3703 - 4096) + 2048) >> 12) - in23;
-    int t29a = ((in9  * (3703 - 4096) + in23 *  1751         + 2048) >> 12) + in9;
-    int t19a = ((in25 * (3857 - 4096) - in7  *  1380         + 2048) >> 12) + in25;
-    int t28a = ((in25 *  1380         + in7  * (3857 - 4096) + 2048) >> 12) + in7;
-    int t20a = ((in5  *   995         - in27 * (3973 - 4096) + 2048) >> 12) - in27;
-    int t27a = ((in5  * (3973 - 4096) + in27 *   995         + 2048) >> 12) + in5;
-    int t21a = ((in21 * (3513 - 4096) - in11 *  2106         + 2048) >> 12) + in21;
-    int t26a = ((in21 *  2106         + in11 * (3513 - 4096) + 2048) >> 12) + in11;
-    int t22a =  (in13 *  1220         - in19 *  1645         + 1024) >> 11;
-    int t25a =  (in13 *  1645         + in19 *  1220         + 1024) >> 11;
-    int t23a = ((in29 * (4052 - 4096) - in3  *   601         + 2048) >> 12) + in29;
-    int t24a = ((in29 *   601         + in3  * (4052 - 4096) + 2048) >> 12) + in3;
+    int t16a, t17a, t18a, t19a, t20a, t21a, t22a, t23a;
+    int t24a, t25a, t26a, t27a, t28a, t29a, t30a, t31a;
+    if (tx64) {
+        t16a = (in1  *   201 + 2048) >> 12;
+        t17a = (in15 * -2751 + 2048) >> 12;
+        t18a = (in9  *  1751 + 2048) >> 12;
+        t19a = (in7  * -1380 + 2048) >> 12;
+        t20a = (in5  *   995 + 2048) >> 12;
+        t21a = (in11 * -2106 + 2048) >> 12;
+        t22a = (in13 *  2440 + 2048) >> 12;
+        t23a = (in3  *  -601 + 2048) >> 12;
+        t24a = (in3  *  4052 + 2048) >> 12;
+        t25a = (in13 *  3290 + 2048) >> 12;
+        t26a = (in11 *  3513 + 2048) >> 12;
+        t27a = (in5  *  3973 + 2048) >> 12;
+        t28a = (in7  *  3857 + 2048) >> 12;
+        t29a = (in9  *  3703 + 2048) >> 12;
+        t30a = (in15 *  3035 + 2048) >> 12;
+        t31a = (in1  *  4091 + 2048) >> 12;
+    } else {
+        const int in17 = c[17 * stride], in19 = c[19 * stride];
+        const int in21 = c[21 * stride], in23 = c[23 * stride];
+        const int in25 = c[25 * stride], in27 = c[27 * stride];
+        const int in29 = c[29 * stride], in31 = c[31 * stride];
 
+        t16a = ((in1  *   201         - in31 * (4091 - 4096) + 2048) >> 12) - in31;
+        t17a = ((in17 * (3035 - 4096) - in15 *  2751         + 2048) >> 12) + in17;
+        t18a = ((in9  *  1751         - in23 * (3703 - 4096) + 2048) >> 12) - in23;
+        t19a = ((in25 * (3857 - 4096) - in7  *  1380         + 2048) >> 12) + in25;
+        t20a = ((in5  *   995         - in27 * (3973 - 4096) + 2048) >> 12) - in27;
+        t21a = ((in21 * (3513 - 4096) - in11 *  2106         + 2048) >> 12) + in21;
+        t22a =  (in13 *  1220         - in19 *  1645         + 1024) >> 11;
+        t23a = ((in29 * (4052 - 4096) - in3  *   601         + 2048) >> 12) + in29;
+        t24a = ((in29 *   601         + in3  * (4052 - 4096) + 2048) >> 12) + in3;
+        t25a =  (in13 *  1645         + in19 *  1220         + 1024) >> 11;
+        t26a = ((in21 *  2106         + in11 * (3513 - 4096) + 2048) >> 12) + in11;
+        t27a = ((in5  * (3973 - 4096) + in27 *   995         + 2048) >> 12) + in5;
+        t28a = ((in25 *  1380         + in7  * (3857 - 4096) + 2048) >> 12) + in7;
+        t29a = ((in9  * (3703 - 4096) + in23 *  1751         + 2048) >> 12) + in9;
+        t30a = ((in17 *  2751         + in15 * (3035 - 4096) + 2048) >> 12) + in15;
+        t31a = ((in1  * (4091 - 4096) + in31 *   201         + 2048) >> 12) + in1;
+    }
+
     int t16 = CLIP(t16a + t17a);
     int t17 = CLIP(t16a - t17a);
     int t18 = CLIP(t19a - t18a);
@@ -298,98 +376,110 @@
     t23a = ((t24  - t23 ) * 181 + 128) >> 8;
     t24a = ((t24  + t23 ) * 181 + 128) >> 8;
 
-    out[ 0 * out_s] = CLIP(tmp[ 0] + t31);
-    out[ 1 * out_s] = CLIP(tmp[ 1] + t30a);
-    out[ 2 * out_s] = CLIP(tmp[ 2] + t29);
-    out[ 3 * out_s] = CLIP(tmp[ 3] + t28a);
-    out[ 4 * out_s] = CLIP(tmp[ 4] + t27);
-    out[ 5 * out_s] = CLIP(tmp[ 5] + t26a);
-    out[ 6 * out_s] = CLIP(tmp[ 6] + t25);
-    out[ 7 * out_s] = CLIP(tmp[ 7] + t24a);
-    out[ 8 * out_s] = CLIP(tmp[ 8] + t23a);
-    out[ 9 * out_s] = CLIP(tmp[ 9] + t22);
-    out[10 * out_s] = CLIP(tmp[10] + t21a);
-    out[11 * out_s] = CLIP(tmp[11] + t20);
-    out[12 * out_s] = CLIP(tmp[12] + t19a);
-    out[13 * out_s] = CLIP(tmp[13] + t18);
-    out[14 * out_s] = CLIP(tmp[14] + t17a);
-    out[15 * out_s] = CLIP(tmp[15] + t16);
-    out[16 * out_s] = CLIP(tmp[15] - t16);
-    out[17 * out_s] = CLIP(tmp[14] - t17a);
-    out[18 * out_s] = CLIP(tmp[13] - t18);
-    out[19 * out_s] = CLIP(tmp[12] - t19a);
-    out[20 * out_s] = CLIP(tmp[11] - t20);
-    out[21 * out_s] = CLIP(tmp[10] - t21a);
-    out[22 * out_s] = CLIP(tmp[ 9] - t22);
-    out[23 * out_s] = CLIP(tmp[ 8] - t23a);
-    out[24 * out_s] = CLIP(tmp[ 7] - t24a);
-    out[25 * out_s] = CLIP(tmp[ 6] - t25);
-    out[26 * out_s] = CLIP(tmp[ 5] - t26a);
-    out[27 * out_s] = CLIP(tmp[ 4] - t27);
-    out[28 * out_s] = CLIP(tmp[ 3] - t28a);
-    out[29 * out_s] = CLIP(tmp[ 2] - t29);
-    out[30 * out_s] = CLIP(tmp[ 1] - t30a);
-    out[31 * out_s] = CLIP(tmp[ 0] - t31);
+    const int t0  = c[ 0 * stride];
+    const int t1  = c[ 2 * stride];
+    const int t2  = c[ 4 * stride];
+    const int t3  = c[ 6 * stride];
+    const int t4  = c[ 8 * stride];
+    const int t5  = c[10 * stride];
+    const int t6  = c[12 * stride];
+    const int t7  = c[14 * stride];
+    const int t8  = c[16 * stride];
+    const int t9  = c[18 * stride];
+    const int t10 = c[20 * stride];
+    const int t11 = c[22 * stride];
+    const int t12 = c[24 * stride];
+    const int t13 = c[26 * stride];
+    const int t14 = c[28 * stride];
+    const int t15 = c[30 * stride];
+
+    c[ 0 * stride] = CLIP(t0  + t31);
+    c[ 1 * stride] = CLIP(t1  + t30a);
+    c[ 2 * stride] = CLIP(t2  + t29);
+    c[ 3 * stride] = CLIP(t3  + t28a);
+    c[ 4 * stride] = CLIP(t4  + t27);
+    c[ 5 * stride] = CLIP(t5  + t26a);
+    c[ 6 * stride] = CLIP(t6  + t25);
+    c[ 7 * stride] = CLIP(t7  + t24a);
+    c[ 8 * stride] = CLIP(t8  + t23a);
+    c[ 9 * stride] = CLIP(t9  + t22);
+    c[10 * stride] = CLIP(t10 + t21a);
+    c[11 * stride] = CLIP(t11 + t20);
+    c[12 * stride] = CLIP(t12 + t19a);
+    c[13 * stride] = CLIP(t13 + t18);
+    c[14 * stride] = CLIP(t14 + t17a);
+    c[15 * stride] = CLIP(t15 + t16);
+    c[16 * stride] = CLIP(t15 - t16);
+    c[17 * stride] = CLIP(t14 - t17a);
+    c[18 * stride] = CLIP(t13 - t18);
+    c[19 * stride] = CLIP(t12 - t19a);
+    c[20 * stride] = CLIP(t11 - t20);
+    c[21 * stride] = CLIP(t10 - t21a);
+    c[22 * stride] = CLIP(t9  - t22);
+    c[23 * stride] = CLIP(t8  - t23a);
+    c[24 * stride] = CLIP(t7  - t24a);
+    c[25 * stride] = CLIP(t6  - t25);
+    c[26 * stride] = CLIP(t5  - t26a);
+    c[27 * stride] = CLIP(t4  - t27);
+    c[28 * stride] = CLIP(t3  - t28a);
+    c[29 * stride] = CLIP(t2  - t29);
+    c[30 * stride] = CLIP(t1  - t30a);
+    c[31 * stride] = CLIP(t0  - t31);
 }
 
-void dav1d_inv_dct64_1d_c(const int32_t *const in, const ptrdiff_t in_s,
-                          int32_t *const out, const ptrdiff_t out_s,
-                          const int max)
+void dav1d_inv_dct32_1d_c(int32_t *const c, const ptrdiff_t stride,
+                          const int min, const int max)
 {
-    const int min = -max - 1;
-    int32_t tmp[32];
+    inv_dct32_1d_internal_c(c, stride, min, max, 0);
+}
 
-    dav1d_inv_dct32_1d_c(in, in_s * 2, tmp, 1, max);
+void dav1d_inv_dct64_1d_c(int32_t *const c, const ptrdiff_t stride,
+                          const int min, const int max)
+{
+    assert(stride > 0);
+    inv_dct32_1d_internal_c(c, stride << 1, min, max, 1);
 
-    const int in1  = in[ 1 * in_s], in3  = in[ 3 * in_s];
-    const int in5  = in[ 5 * in_s], in7  = in[ 7 * in_s];
-    const int in9  = in[ 9 * in_s], in11 = in[11 * in_s];
-    const int in13 = in[13 * in_s], in15 = in[15 * in_s];
-    const int in17 = in[17 * in_s], in19 = in[19 * in_s];
-    const int in21 = in[21 * in_s], in23 = in[23 * in_s];
-    const int in25 = in[25 * in_s], in27 = in[27 * in_s];
-    const int in29 = in[29 * in_s], in31 = in[31 * in_s];
-    const int in33 = in[33 * in_s], in35 = in[35 * in_s];
-    const int in37 = in[37 * in_s], in39 = in[39 * in_s];
-    const int in41 = in[41 * in_s], in43 = in[43 * in_s];
-    const int in45 = in[45 * in_s], in47 = in[47 * in_s];
-    const int in49 = in[49 * in_s], in51 = in[51 * in_s];
-    const int in53 = in[53 * in_s], in55 = in[55 * in_s];
-    const int in57 = in[57 * in_s], in59 = in[59 * in_s];
-    const int in61 = in[61 * in_s], in63 = in[63 * in_s];
+    const int in1  = c[ 1 * stride], in3  = c[ 3 * stride];
+    const int in5  = c[ 5 * stride], in7  = c[ 7 * stride];
+    const int in9  = c[ 9 * stride], in11 = c[11 * stride];
+    const int in13 = c[13 * stride], in15 = c[15 * stride];
+    const int in17 = c[17 * stride], in19 = c[19 * stride];
+    const int in21 = c[21 * stride], in23 = c[23 * stride];
+    const int in25 = c[25 * stride], in27 = c[27 * stride];
+    const int in29 = c[29 * stride], in31 = c[31 * stride];
 
-    int t32a = ((in1  *   101         - in63 * (4095 - 4096) + 2048) >> 12) - in63;
-    int t33a = ((in33 * (2967 - 4096) - in31 *  2824         + 2048) >> 12) + in33;
-    int t34a = ((in17 *  1660         - in47 * (3745 - 4096) + 2048) >> 12) - in47;
-    int t35a =  (in49 *  1911         - in15 *   737         + 1024) >> 11;
-    int t36a = ((in9  *   897         - in55 * (3996 - 4096) + 2048) >> 12) - in55;
-    int t37a = ((in41 * (3461 - 4096) - in23 *  2191         + 2048) >> 12) + in41;
-    int t38a = ((in25 *  2359         - in39 * (3349 - 4096) + 2048) >> 12) - in39;
-    int t39a =  (in57 *  2018         - in7  *   350         + 1024) >> 11;
-    int t40a = ((in5  *   501         - in59 * (4065 - 4096) + 2048) >> 12) - in59;
-    int t41a = ((in37 * (3229 - 4096) - in27 *  2520         + 2048) >> 12) + in37;
-    int t42a = ((in21 *  2019         - in43 * (3564 - 4096) + 2048) >> 12) - in43;
-    int t43a =  (in53 *  1974         - in11 *   546         + 1024) >> 11;
-    int t44a = ((in13 *  1285         - in51 * (3889 - 4096) + 2048) >> 12) - in51;
-    int t45a = ((in45 * (3659 - 4096) - in19 *  1842         + 2048) >> 12) + in45;
-    int t46a = ((in29 *  2675         - in35 * (3102 - 4096) + 2048) >> 12) - in35;
-    int t47a = ((in61 * (4085 - 4096) - in3  *   301         + 2048) >> 12) + in61;
-    int t48a = ((in61 *   301         + in3  * (4085 - 4096) + 2048) >> 12) + in3;
-    int t49a = ((in29 * (3102 - 4096) + in35 *  2675         + 2048) >> 12) + in29;
-    int t50a = ((in45 *  1842         + in19 * (3659 - 4096) + 2048) >> 12) + in19;
-    int t51a = ((in13 * (3889 - 4096) + in51 *  1285         + 2048) >> 12) + in13;
-    int t52a =  (in53 *   546         + in11 *  1974         + 1024) >> 11;
-    int t53a = ((in21 * (3564 - 4096) + in43 *  2019         + 2048) >> 12) + in21;
-    int t54a = ((in37 *  2520         + in27 * (3229 - 4096) + 2048) >> 12) + in27;
-    int t55a = ((in5  * (4065 - 4096) + in59 *   501         + 2048) >> 12) + in5;
-    int t56a =  (in57 *   350         + in7  *  2018         + 1024) >> 11;
-    int t57a = ((in25 * (3349 - 4096) + in39 *  2359         + 2048) >> 12) + in25;
-    int t58a = ((in41 *  2191         + in23 * (3461 - 4096) + 2048) >> 12) + in23;
-    int t59a = ((in9  * (3996 - 4096) + in55 *   897         + 2048) >> 12) + in9;
-    int t60a =  (in49 *   737         + in15 *  1911         + 1024) >> 11;
-    int t61a = ((in17 * (3745 - 4096) + in47 *  1660         + 2048) >> 12) + in17;
-    int t62a = ((in33 *  2824         + in31 * (2967 - 4096) + 2048) >> 12) + in31;
-    int t63a = ((in1  * (4095 - 4096) + in63 *   101         + 2048) >> 12) + in1;
+    int t32a = (in1  *   101 + 2048) >> 12;
+    int t33a = (in31 * -2824 + 2048) >> 12;
+    int t34a = (in17 *  1660 + 2048) >> 12;
+    int t35a = (in15 * -1474 + 2048) >> 12;
+    int t36a = (in9  *   897 + 2048) >> 12;
+    int t37a = (in23 * -2191 + 2048) >> 12;
+    int t38a = (in25 *  2359 + 2048) >> 12;
+    int t39a = (in7  *  -700 + 2048) >> 12;
+    int t40a = (in5  *   501 + 2048) >> 12;
+    int t41a = (in27 * -2520 + 2048) >> 12;
+    int t42a = (in21 *  2019 + 2048) >> 12;
+    int t43a = (in11 * -1092 + 2048) >> 12;
+    int t44a = (in13 *  1285 + 2048) >> 12;
+    int t45a = (in19 * -1842 + 2048) >> 12;
+    int t46a = (in29 *  2675 + 2048) >> 12;
+    int t47a = (in3  *  -301 + 2048) >> 12;
+    int t48a = (in3  *  4085 + 2048) >> 12;
+    int t49a = (in29 *  3102 + 2048) >> 12;
+    int t50a = (in19 *  3659 + 2048) >> 12;
+    int t51a = (in13 *  3889 + 2048) >> 12;
+    int t52a = (in11 *  3948 + 2048) >> 12;
+    int t53a = (in21 *  3564 + 2048) >> 12;
+    int t54a = (in27 *  3229 + 2048) >> 12;
+    int t55a = (in5  *  4065 + 2048) >> 12;
+    int t56a = (in7  *  4036 + 2048) >> 12;
+    int t57a = (in25 *  3349 + 2048) >> 12;
+    int t58a = (in23 *  3461 + 2048) >> 12;
+    int t59a = (in9  *  3996 + 2048) >> 12;
+    int t60a = (in15 *  3822 + 2048) >> 12;
+    int t61a = (in17 *  3745 + 2048) >> 12;
+    int t62a = (in31 *  2967 + 2048) >> 12;
+    int t63a = (in1  *  4095 + 2048) >> 12;
 
     int t32 = CLIP(t32a + t33a);
     int t33 = CLIP(t32a - t33a);
@@ -591,76 +681,111 @@
     t54  = ((t41a + t54a) * 181 + 128) >> 8;
     t55a = ((t40  + t55 ) * 181 + 128) >> 8;
 
-    out[ 0 * out_s] = CLIP(tmp[ 0] + t63a);
-    out[ 1 * out_s] = CLIP(tmp[ 1] + t62);
-    out[ 2 * out_s] = CLIP(tmp[ 2] + t61a);
-    out[ 3 * out_s] = CLIP(tmp[ 3] + t60);
-    out[ 4 * out_s] = CLIP(tmp[ 4] + t59a);
-    out[ 5 * out_s] = CLIP(tmp[ 5] + t58);
-    out[ 6 * out_s] = CLIP(tmp[ 6] + t57a);
-    out[ 7 * out_s] = CLIP(tmp[ 7] + t56);
-    out[ 8 * out_s] = CLIP(tmp[ 8] + t55a);
-    out[ 9 * out_s] = CLIP(tmp[ 9] + t54);
-    out[10 * out_s] = CLIP(tmp[10] + t53a);
-    out[11 * out_s] = CLIP(tmp[11] + t52);
-    out[12 * out_s] = CLIP(tmp[12] + t51a);
-    out[13 * out_s] = CLIP(tmp[13] + t50);
-    out[14 * out_s] = CLIP(tmp[14] + t49a);
-    out[15 * out_s] = CLIP(tmp[15] + t48);
-    out[16 * out_s] = CLIP(tmp[16] + t47);
-    out[17 * out_s] = CLIP(tmp[17] + t46a);
-    out[18 * out_s] = CLIP(tmp[18] + t45);
-    out[19 * out_s] = CLIP(tmp[19] + t44a);
-    out[20 * out_s] = CLIP(tmp[20] + t43);
-    out[21 * out_s] = CLIP(tmp[21] + t42a);
-    out[22 * out_s] = CLIP(tmp[22] + t41);
-    out[23 * out_s] = CLIP(tmp[23] + t40a);
-    out[24 * out_s] = CLIP(tmp[24] + t39);
-    out[25 * out_s] = CLIP(tmp[25] + t38a);
-    out[26 * out_s] = CLIP(tmp[26] + t37);
-    out[27 * out_s] = CLIP(tmp[27] + t36a);
-    out[28 * out_s] = CLIP(tmp[28] + t35);
-    out[29 * out_s] = CLIP(tmp[29] + t34a);
-    out[30 * out_s] = CLIP(tmp[30] + t33);
-    out[31 * out_s] = CLIP(tmp[31] + t32a);
-    out[32 * out_s] = CLIP(tmp[31] - t32a);
-    out[33 * out_s] = CLIP(tmp[30] - t33);
-    out[34 * out_s] = CLIP(tmp[29] - t34a);
-    out[35 * out_s] = CLIP(tmp[28] - t35);
-    out[36 * out_s] = CLIP(tmp[27] - t36a);
-    out[37 * out_s] = CLIP(tmp[26] - t37);
-    out[38 * out_s] = CLIP(tmp[25] - t38a);
-    out[39 * out_s] = CLIP(tmp[24] - t39);
-    out[40 * out_s] = CLIP(tmp[23] - t40a);
-    out[41 * out_s] = CLIP(tmp[22] - t41);
-    out[42 * out_s] = CLIP(tmp[21] - t42a);
-    out[43 * out_s] = CLIP(tmp[20] - t43);
-    out[44 * out_s] = CLIP(tmp[19] - t44a);
-    out[45 * out_s] = CLIP(tmp[18] - t45);
-    out[46 * out_s] = CLIP(tmp[17] - t46a);
-    out[47 * out_s] = CLIP(tmp[16] - t47);
-    out[48 * out_s] = CLIP(tmp[15] - t48);
-    out[49 * out_s] = CLIP(tmp[14] - t49a);
-    out[50 * out_s] = CLIP(tmp[13] - t50);
-    out[51 * out_s] = CLIP(tmp[12] - t51a);
-    out[52 * out_s] = CLIP(tmp[11] - t52);
-    out[53 * out_s] = CLIP(tmp[10] - t53a);
-    out[54 * out_s] = CLIP(tmp[ 9] - t54);
-    out[55 * out_s] = CLIP(tmp[ 8] - t55a);
-    out[56 * out_s] = CLIP(tmp[ 7] - t56);
-    out[57 * out_s] = CLIP(tmp[ 6] - t57a);
-    out[58 * out_s] = CLIP(tmp[ 5] - t58);
-    out[59 * out_s] = CLIP(tmp[ 4] - t59a);
-    out[60 * out_s] = CLIP(tmp[ 3] - t60);
-    out[61 * out_s] = CLIP(tmp[ 2] - t61a);
-    out[62 * out_s] = CLIP(tmp[ 1] - t62);
-    out[63 * out_s] = CLIP(tmp[ 0] - t63a);
+    const int t0  = c[ 0 * stride];
+    const int t1  = c[ 2 * stride];
+    const int t2  = c[ 4 * stride];
+    const int t3  = c[ 6 * stride];
+    const int t4  = c[ 8 * stride];
+    const int t5  = c[10 * stride];
+    const int t6  = c[12 * stride];
+    const int t7  = c[14 * stride];
+    const int t8  = c[16 * stride];
+    const int t9  = c[18 * stride];
+    const int t10 = c[20 * stride];
+    const int t11 = c[22 * stride];
+    const int t12 = c[24 * stride];
+    const int t13 = c[26 * stride];
+    const int t14 = c[28 * stride];
+    const int t15 = c[30 * stride];
+    const int t16 = c[32 * stride];
+    const int t17 = c[34 * stride];
+    const int t18 = c[36 * stride];
+    const int t19 = c[38 * stride];
+    const int t20 = c[40 * stride];
+    const int t21 = c[42 * stride];
+    const int t22 = c[44 * stride];
+    const int t23 = c[46 * stride];
+    const int t24 = c[48 * stride];
+    const int t25 = c[50 * stride];
+    const int t26 = c[52 * stride];
+    const int t27 = c[54 * stride];
+    const int t28 = c[56 * stride];
+    const int t29 = c[58 * stride];
+    const int t30 = c[60 * stride];
+    const int t31 = c[62 * stride];
+
+    c[ 0 * stride] = CLIP(t0  + t63a);
+    c[ 1 * stride] = CLIP(t1  + t62);
+    c[ 2 * stride] = CLIP(t2  + t61a);
+    c[ 3 * stride] = CLIP(t3  + t60);
+    c[ 4 * stride] = CLIP(t4  + t59a);
+    c[ 5 * stride] = CLIP(t5  + t58);
+    c[ 6 * stride] = CLIP(t6  + t57a);
+    c[ 7 * stride] = CLIP(t7  + t56);
+    c[ 8 * stride] = CLIP(t8  + t55a);
+    c[ 9 * stride] = CLIP(t9  + t54);
+    c[10 * stride] = CLIP(t10 + t53a);
+    c[11 * stride] = CLIP(t11 + t52);
+    c[12 * stride] = CLIP(t12 + t51a);
+    c[13 * stride] = CLIP(t13 + t50);
+    c[14 * stride] = CLIP(t14 + t49a);
+    c[15 * stride] = CLIP(t15 + t48);
+    c[16 * stride] = CLIP(t16 + t47);
+    c[17 * stride] = CLIP(t17 + t46a);
+    c[18 * stride] = CLIP(t18 + t45);
+    c[19 * stride] = CLIP(t19 + t44a);
+    c[20 * stride] = CLIP(t20 + t43);
+    c[21 * stride] = CLIP(t21 + t42a);
+    c[22 * stride] = CLIP(t22 + t41);
+    c[23 * stride] = CLIP(t23 + t40a);
+    c[24 * stride] = CLIP(t24 + t39);
+    c[25 * stride] = CLIP(t25 + t38a);
+    c[26 * stride] = CLIP(t26 + t37);
+    c[27 * stride] = CLIP(t27 + t36a);
+    c[28 * stride] = CLIP(t28 + t35);
+    c[29 * stride] = CLIP(t29 + t34a);
+    c[30 * stride] = CLIP(t30 + t33);
+    c[31 * stride] = CLIP(t31 + t32a);
+    c[32 * stride] = CLIP(t31 - t32a);
+    c[33 * stride] = CLIP(t30 - t33);
+    c[34 * stride] = CLIP(t29 - t34a);
+    c[35 * stride] = CLIP(t28 - t35);
+    c[36 * stride] = CLIP(t27 - t36a);
+    c[37 * stride] = CLIP(t26 - t37);
+    c[38 * stride] = CLIP(t25 - t38a);
+    c[39 * stride] = CLIP(t24 - t39);
+    c[40 * stride] = CLIP(t23 - t40a);
+    c[41 * stride] = CLIP(t22 - t41);
+    c[42 * stride] = CLIP(t21 - t42a);
+    c[43 * stride] = CLIP(t20 - t43);
+    c[44 * stride] = CLIP(t19 - t44a);
+    c[45 * stride] = CLIP(t18 - t45);
+    c[46 * stride] = CLIP(t17 - t46a);
+    c[47 * stride] = CLIP(t16 - t47);
+    c[48 * stride] = CLIP(t15 - t48);
+    c[49 * stride] = CLIP(t14 - t49a);
+    c[50 * stride] = CLIP(t13 - t50);
+    c[51 * stride] = CLIP(t12 - t51a);
+    c[52 * stride] = CLIP(t11 - t52);
+    c[53 * stride] = CLIP(t10 - t53a);
+    c[54 * stride] = CLIP(t9  - t54);
+    c[55 * stride] = CLIP(t8  - t55a);
+    c[56 * stride] = CLIP(t7  - t56);
+    c[57 * stride] = CLIP(t6  - t57a);
+    c[58 * stride] = CLIP(t5  - t58);
+    c[59 * stride] = CLIP(t4  - t59a);
+    c[60 * stride] = CLIP(t3  - t60);
+    c[61 * stride] = CLIP(t2  - t61a);
+    c[62 * stride] = CLIP(t1  - t62);
+    c[63 * stride] = CLIP(t0  - t63a);
 }
 
-void dav1d_inv_adst4_1d_c(const int32_t *const in, const ptrdiff_t in_s,
-                          int32_t *const out, const ptrdiff_t out_s,
-                          const int range)
+static NOINLINE void /* shared body behind the adst4 and flipadst4 entry points */
+inv_adst4_1d_internal_c(const int32_t *const in, const ptrdiff_t in_s,
+                        const int min, const int max,
+                        int32_t *const out, const ptrdiff_t out_s)
 {
+    assert(in_s > 0 && out_s != 0); /* out_s may be negative: the flipadst wrapper passes -stride */
     const int in0 = in[0 * in_s], in1 = in[1 * in_s];
     const int in2 = in[2 * in_s], in3 = in[3 * in_s];
 
@@ -676,11 +801,12 @@
                      in0 + in2 - in1;
 }
 
-void dav1d_inv_adst8_1d_c(const int32_t *const in, const ptrdiff_t in_s,
-                          int32_t *const out, const ptrdiff_t out_s,
-                          const int max)
+static NOINLINE void /* shared body behind the adst8 and flipadst8 entry points */
+inv_adst8_1d_internal_c(const int32_t *const in, const ptrdiff_t in_s,
+                        const int min, const int max, /* min is passed in now; it used to be derived as -max - 1 */
+                        int32_t *const out, const ptrdiff_t out_s)
 {
-    const int min = -max - 1;
+    assert(in_s > 0 && out_s != 0); /* out_s may be negative: the flipadst wrapper passes -stride */
     const int in0 = in[0 * in_s], in1 = in[1 * in_s];
     const int in2 = in[2 * in_s], in3 = in[3 * in_s];
     const int in4 = in[4 * in_s], in5 = in[5 * in_s];
@@ -724,11 +850,12 @@
     out[5 * out_s] = -(((t6 - t7) * 181 + 128) >> 8);
 }
 
-void dav1d_inv_adst16_1d_c(const int32_t *const in, const ptrdiff_t in_s,
-                           int32_t *const out, const ptrdiff_t out_s,
-                           const int max)
+static NOINLINE void /* shared body behind the adst16 and flipadst16 entry points */
+inv_adst16_1d_internal_c(const int32_t *const in, const ptrdiff_t in_s,
+                         const int min, const int max, /* min is passed in now; it used to be derived as -max - 1 */
+                         int32_t *const out, const ptrdiff_t out_s)
 {
-    const int min = -max - 1;
+    assert(in_s > 0 && out_s != 0); /* out_s may be negative: the flipadst wrapper passes -stride */
     const int in0  = in[ 0 * in_s], in1  = in[ 1 * in_s];
     const int in2  = in[ 2 * in_s], in3  = in[ 3 * in_s];
     const int in4  = in[ 4 * in_s], in5  = in[ 5 * in_s];
@@ -834,59 +961,66 @@
     out[10 * out_s] =   ((t14a - t15a) * 181 + 128) >> 8;
 }
 
-#define flip_inv_adst(sz) \
-void dav1d_inv_flipadst##sz##_1d_c(const int32_t *const in, const ptrdiff_t in_s, \
-                                   int32_t *const out, const ptrdiff_t out_s, \
-                                   const int range) \
+#define inv_adst_1d(sz) /* emits the adst and flipadst entry points for one transform size */ \
+void dav1d_inv_adst##sz##_1d_c(int32_t *const c, const ptrdiff_t stride, \
+                               const int min, const int max) \
 { \
-    dav1d_inv_adst##sz##_1d_c(in, in_s, &out[(sz - 1) * out_s], -out_s, range); \
+    inv_adst##sz##_1d_internal_c(c, stride, min, max, c, stride); /* in place: c is both input and output */ \
+} \
+void dav1d_inv_flipadst##sz##_1d_c(int32_t *const c, const ptrdiff_t stride, \
+                                   const int min, const int max) \
+{ \
+    inv_adst##sz##_1d_internal_c(c, stride, min, max, /* same buffer, but output starts at the last slot... */ \
+                                 &c[(sz - 1) * stride], -stride); /* ...and steps backwards */ \
 }
 
-flip_inv_adst(4)
-flip_inv_adst(8)
-flip_inv_adst(16)
+inv_adst_1d( 4) /* dav1d_inv_adst4_1d_c and dav1d_inv_flipadst4_1d_c */
+inv_adst_1d( 8) /* dav1d_inv_adst8_1d_c and dav1d_inv_flipadst8_1d_c */
+inv_adst_1d(16) /* dav1d_inv_adst16_1d_c and dav1d_inv_flipadst16_1d_c */
 
-#undef flip_inv_adst
+#undef inv_adst_1d /* generator macro is not needed past this point */
 
-void dav1d_inv_identity4_1d_c(const int32_t *const in, const ptrdiff_t in_s,
-                              int32_t *const out, const ptrdiff_t out_s,
-                              const int range)
+void dav1d_inv_identity4_1d_c(int32_t *const c, const ptrdiff_t stride, /* rewrites 4 values of c in place */
+                              const int min, const int max) /* min and max are not referenced in this body */
 {
-    for (int i = 0; i < 4; i++)
-        out[out_s * i] = in[in_s * i] + ((in[in_s * i] * 1697 + 2048) >> 12);
+    assert(stride > 0);
+    for (int i = 0; i < 4; i++) {
+        const int in = c[stride * i]; /* load once; the same slot is overwritten on the next line */
+        c[stride * i] = in + ((in * 1697 + 2048) >> 12); /* in + in * 1697 / 4096, +2048 being the rounding bias */
+    }
 }
 
-void dav1d_inv_identity8_1d_c(const int32_t *const in, const ptrdiff_t in_s,
-                              int32_t *const out, const ptrdiff_t out_s,
-                              const int range)
+void dav1d_inv_identity8_1d_c(int32_t *const c, const ptrdiff_t stride, /* rewrites 8 values of c in place */
+                              const int min, const int max) /* min and max are not referenced in this body */
 {
+    assert(stride > 0);
     for (int i = 0; i < 8; i++)
-        out[out_s * i] = in[in_s * i] * 2;
+        c[stride * i] *= 2; /* each value is doubled where it sits */
 }
 
-void dav1d_inv_identity16_1d_c(const int32_t *const in, const ptrdiff_t in_s,
-                               int32_t *const out, const ptrdiff_t out_s,
-                               const int range)
+void dav1d_inv_identity16_1d_c(int32_t *const c, const ptrdiff_t stride, /* rewrites 16 values of c in place */
+                               const int min, const int max) /* min and max are not referenced in this body */
 {
-    for (int i = 0; i < 16; i++)
-        out[out_s * i] = 2 * in[in_s * i] + ((in[in_s * i] * 1697 + 1024) >> 11);
+    assert(stride > 0);
+    for (int i = 0; i < 16; i++) {
+        const int in = c[stride * i]; /* load once; the same slot is overwritten on the next line */
+        c[stride * i] = 2 * in + ((in * 1697 + 1024) >> 11); /* 2 * in + in * 1697 / 2048, +1024 being the rounding bias */
+    }
 }
 
-void dav1d_inv_identity32_1d_c(const int32_t *const in, const ptrdiff_t in_s,
-                               int32_t *const out, const ptrdiff_t out_s,
-                               const int range)
+void dav1d_inv_identity32_1d_c(int32_t *const c, const ptrdiff_t stride, /* rewrites 32 values of c in place */
+                               const int min, const int max) /* min and max are not referenced in this body */
 {
+    assert(stride > 0);
     for (int i = 0; i < 32; i++)
-        out[out_s * i] = in[in_s * i] * 4;
+        c[stride * i] *= 4; /* each value is multiplied by 4 where it sits */
 }
 
-void dav1d_inv_wht4_1d_c(const int32_t *const in, const ptrdiff_t in_s,
-                         int32_t *const out, const ptrdiff_t out_s,
-                         const int pass)
-{
-    const int sh = 2 * !pass;
-    const int in0 = in[0 * in_s] >> sh, in1 = in[1 * in_s] >> sh;
-    const int in2 = in[2 * in_s] >> sh, in3 = in[3 * in_s] >> sh;
+void dav1d_inv_wht4_1d_c(int32_t *const c, const ptrdiff_t stride) { /* 4-point pass over c, results stored back into c */
+    assert(stride > 0);
+    const int in0 = c[0 * stride], in1 = c[1 * stride]; /* the pass-dependent >> sh pre-shift is no longer applied here */
+    const int in2 = c[2 * stride], in3 = c[3 * stride]; /* all four inputs are held in locals before any store */
+
     const int t0 = in0 + in1;
     const int t2 = in2 - in3;
     const int t4 = (t0 - t2) >> 1;
@@ -893,8 +1027,8 @@
     const int t3 = t4 - in3;
     const int t1 = t4 - in1;
 
-    out[0 * out_s] = t0 - t3;
-    out[1 * out_s] = t3;
-    out[2 * out_s] = t1;
-    out[3 * out_s] = t2 + t1;
+    c[0 * stride] = t0 - t3; /* outputs overwrite the input slots */
+    c[1 * stride] = t3;
+    c[2 * stride] = t1;
+    c[3 * stride] = t2 + t1;
 }
--- a/src/itx_1d.h
+++ b/src/itx_1d.h
@@ -32,8 +32,7 @@
 #define DAV1D_SRC_ITX_1D_H
 
 #define decl_itx_1d_fn(name) \
-void (name)(const int32_t *in, ptrdiff_t in_s, \
-            int32_t *out, ptrdiff_t out_s, const int range)
+void (name)(int32_t *c, ptrdiff_t stride, int min, int max) /* one non-const buffer c replaces the separate in/out pointers */
 typedef decl_itx_1d_fn(*itx_1d_fn);
 
 decl_itx_1d_fn(dav1d_inv_dct4_1d_c);
@@ -55,6 +54,6 @@
 decl_itx_1d_fn(dav1d_inv_identity16_1d_c);
 decl_itx_1d_fn(dav1d_inv_identity32_1d_c);
 
-decl_itx_1d_fn(dav1d_inv_wht4_1d_c);
+void dav1d_inv_wht4_1d_c(int32_t *c, ptrdiff_t stride); /* takes no min/max, so it is declared outside decl_itx_1d_fn */
 
 #endif /* DAV1D_SRC_ITX_1D_H */
--- a/src/itx_tmpl.c
+++ b/src/itx_tmpl.c
@@ -37,64 +37,66 @@
 #include "src/itx.h"
 #include "src/itx_1d.h"
 
-static void NOINLINE
-inv_txfm_add_c(pixel *dst, const ptrdiff_t stride,
-               coef *const coeff, const int eob,
-               const int w, const int h, const int shift,
+static NOINLINE void /* runs first_1d_fn over rows, second_1d_fn over columns, then adds the result to dst */
+inv_txfm_add_c(pixel *dst, const ptrdiff_t stride, coef *const coeff,
+               const int eob, const int w, const int h, const int shift,
                const itx_1d_fn first_1d_fn, const itx_1d_fn second_1d_fn,
                const int has_dconly HIGHBD_DECL_SUFFIX)
 {
-    int i, j;
-    assert((h >= 4 && h <= 64) && (w >= 4 && w <= 64));
+    assert(w >= 4 && w <= 64);
+    assert(h >= 4 && h <= 64);
+    assert(eob >= 0); /* the dc-only test below relies on eob never being negative */
+
     const int is_rect2 = w * 2 == h || h * 2 == w;
-    const int bitdepth = bitdepth_from_max(bitdepth_max);
     const int rnd = (1 << shift) >> 1;
 
-    if (has_dconly && eob == 0) {
+    if (eob < has_dconly) { /* equals has_dconly && eob == 0 when has_dconly is 0 or 1, given eob >= 0 */
         int dc = coeff[0];
         coeff[0] = 0;
         if (is_rect2)
-            dc = (dc * 2896 + 2048) >> 12;
-        dc = (dc * 2896 + 2048) >> 12;
+            dc = (dc * 181 + 128) >> 8; /* same scale as the old 2896 >> 12 form (2896 == 16 * 181, 2048 == 16 * 128) */
+        dc = (dc * 181 + 128) >> 8;
         dc = (dc + rnd) >> shift;
-        dc = (dc * 2896 + 2048) >> 12;
-        dc = (dc + 8) >> 4;
-        for (j = 0; j < h; j++)
-            for (i = 0; i < w; i++)
-                dst[i + j * PXSTRIDE(stride)] =
-                    iclip_pixel(dst[i + j * PXSTRIDE(stride)] + dc);
+        dc = (dc * 181 + 128 + 2048) >> 12; /* fuses the last 181 >> 8 scale with the old (dc + 8) >> 4 rounding step */
+        for (int y = 0; y < h; y++, dst += PXSTRIDE(stride)) /* the one dc value is added to every pixel of the w x h block */
+            for (int x = 0; x < w; x++)
+                dst[x] = iclip_pixel(dst[x] + dc);
         return;
     }
-    assert(eob > 0 || (eob == 0 && !has_dconly));
 
-    const ptrdiff_t sh = imin(h, 32), sw = imin(w, 32);
-    // Maximum value for h and w is 64
-    int32_t tmp[4096 /* w * h */], out[64 /* h */], in_mem[64 /* w */];
-    const int row_clip_max = (1 << (bitdepth + 8 - 1)) - 1;
-    const int col_clip_max = (1 << (imax(bitdepth + 6, 16) - 1)) - 1;
+    const int sh = imin(h, 32), sw = imin(w, 32); /* at most 32 x 32 coefficients are read from coeff */
+#if BITDEPTH == 8
+    const int row_clip_min = INT16_MIN;
+    const int col_clip_min = INT16_MIN;
+#else
+    const int row_clip_min = (int) ((unsigned) ~bitdepth_max << 7); /* -(bitdepth_max + 1) * 128; shifted as unsigned so no negative int is shifted */
+    const int col_clip_min = (int) ((unsigned) ~bitdepth_max << 5); /* -(bitdepth_max + 1) * 32 */
+#endif
+    const int row_clip_max = ~row_clip_min; /* ~min == -min - 1 */
+    const int col_clip_max = ~col_clip_min;
 
-    if (w != sw) memset(&in_mem[sw], 0, (w - sw) * sizeof(*in_mem));
-    for (i = 0; i < sh; i++) {
-        for (j = 0; j < sw; j++) {
-            in_mem[j] = coeff[i + j * sh];
-            if (is_rect2)
-                in_mem[j] = (in_mem[j] * 2896 + 2048) >> 12;
-        }
-        first_1d_fn(in_mem, 1, &tmp[i * w], 1, row_clip_max);
-        for (j = 0; j < w; j++)
-            tmp[i * w + j] = iclip((tmp[i * w + j] + rnd) >> shift,
-                                   -col_clip_max - 1, col_clip_max);
+    int32_t tmp[64 * 64], *c = tmp; /* large enough for the biggest w * h the asserts allow */
+    for (int y = 0; y < sh; y++, c += w) { /* c walks tmp one row (w values) at a time */
+        if (is_rect2)
+            for (int x = 0; x < sw; x++)
+                c[x] = (coeff[y + x * sh] * 181 + 128) >> 8; /* value for row y, column x is read from coeff[y + x * sh] and pre-scaled */
+        else
+            for (int x = 0; x < sw; x++)
+                c[x] = coeff[y + x * sh];
+        first_1d_fn(c, 1, row_clip_min, row_clip_max); /* NOTE(review): c[sw..w-1] is not written beforehand when w > sw; presumably first_1d_fn does not read it, confirm */
     }
 
-    if (h != sh) memset(&tmp[sh * w], 0, w * (h - sh) * sizeof(*tmp));
-    for (i = 0; i < w; i++) {
-        second_1d_fn(&tmp[i], w, out, 1, col_clip_max);
-        for (j = 0; j < h; j++)
-            dst[i + j * PXSTRIDE(stride)] =
-                iclip_pixel(dst[i + j * PXSTRIDE(stride)] +
-                            ((out[j] + 8) >> 4));
-    }
-    memset(coeff, 0, sizeof(*coeff) * sh * sw);
+    memset(coeff, 0, sizeof(*coeff) * sw * sh); /* every coefficient that was read is cleared */
+    for (int i = 0; i < w * sh; i++)
+        tmp[i] = iclip((tmp[i] + rnd) >> shift, col_clip_min, col_clip_max); /* round, shift and clamp between the two passes */
+
+    for (int x = 0; x < w; x++)
+        second_1d_fn(&tmp[x], w, col_clip_min, col_clip_max); /* column x in place, stride w; NOTE(review): rows sh..h-1 of tmp are never written when h > sh; presumably not read, confirm */
+
+    c = tmp;
+    for (int y = 0; y < h; y++, dst += PXSTRIDE(stride))
+        for (int x = 0; x < w; x++)
+            dst[x] = iclip_pixel(dst[x] + ((*c++ + 8) >> 4)); /* tmp is consumed linearly, w values per dst row, with a rounded >> 4 */
 }
 
 #define inv_txfm_fn(type1, type2, w, h, shift, has_dconly) \
@@ -161,21 +163,21 @@
                                        coef *const coeff, const int eob
                                        HIGHBD_DECL_SUFFIX)
 {
-    int32_t tmp[4 * 4], out[4], in_mem[4];
-
-    for (int i = 0; i < 4; i++) {
-        for (int j = 0; j < 4; j++)
-            in_mem[j] = coeff[i + j * 4];
-        dav1d_inv_wht4_1d_c(in_mem, 1, &tmp[i * 4], 1, 0);
+    int32_t tmp[4 * 4], *c = tmp;
+    for (int y = 0; y < 4; y++, c += 4) {
+        for (int x = 0; x < 4; x++)
+            c[x] = coeff[y + x * 4] >> 2;
+        dav1d_inv_wht4_1d_c(c, 1);
     }
-
-    for (int i = 0; i < 4; i++) {
-        dav1d_inv_wht4_1d_c(&tmp[i], 4, out, 1, 1);
-        for (int j = 0; j < 4; j++)
-            dst[i + j * PXSTRIDE(stride)] =
-                iclip_pixel(dst[i + j * PXSTRIDE(stride)] + out[j]);
-    }
     memset(coeff, 0, sizeof(*coeff) * 4 * 4);
+
+    for (int x = 0; x < 4; x++)
+        dav1d_inv_wht4_1d_c(&tmp[x], 4);
+
+    c = tmp;
+    for (int y = 0; y < 4; y++, dst += PXSTRIDE(stride))
+        for (int x = 0; x < 4; x++)
+            dst[x] = iclip_pixel(dst[x] + *c++);
 }
 
 COLD void bitfn(dav1d_itx_dsp_init)(Dav1dInvTxfmDSPContext *const c) {