shithub: dav1d

Download patch

ref: 93c4bea2d45d7caf5cc6ab712d938dc6f74b98a2
parent: 0ba64ee5a44491daa230e686228803316a4d1f9d
author: Henrik Gramner <[email protected]>
date: Fri Oct 19 21:32:21 EDT 2018

x86: Add pal_pred AVX2 asm

--- a/src/ipred.h
+++ b/src/ipred.h
@@ -70,7 +70,7 @@
  */
 #define decl_pal_pred_fn(name) \
 void (name)(pixel *dst, ptrdiff_t stride, const uint16_t *pal, \
-            const uint8_t *idx, const int w, const int h)
+            const uint8_t *idx, int w, int h) /* idx: w*h palette indices, row stride w */
 typedef decl_pal_pred_fn(*pal_pred_fn);
 
 typedef struct Dav1dIntraPredDSPContext {
--- a/src/x86/ipred.asm
+++ b/src/x86/ipred.asm
@@ -93,6 +93,7 @@
 JMP_TABLE ipred_cfl,      avx2, h4, h8, h16, h32, w4, w8, w16, w32, \
                                 s4-8*4, s8-8*4, s16-8*4, s32-8*4
 JMP_TABLE ipred_cfl_left, avx2, h4, h8, h16, h32
+JMP_TABLE pal_pred,       avx2, w4, w8, w16, w32, w64 ; one entry per block width, indexed by tzcnt(w)
 
 SECTION .text
 
@@ -1514,5 +1515,83 @@
     add                  wq, t0
     movifnidn           acq, acmp
     jmp                  wq
+
+cglobal pal_pred, 4, 6, 5, dst, stride, pal, idx, w, h ; dst[y*stride+x] = pal[idx[y*w+x]], 1 byte/pixel
+    vbroadcasti128       m4, [palq]            ; 8 x 16-bit entries in both lanes: ymm pshufb looks up per 128-bit lane
+    lea                  r2, [pal_pred_avx2_table]
+    tzcnt                wd, wm                ; log2(w); the table covers w = 4..64
+    movifnidn            hd, hm
+    movsxd               wq, [r2+wq*4]         ; 32-bit offset of the .wN entry point
+    packuswb             m4, m4                ; words -> bytes; exact only for entries <= 255
+    add                  wq, r2
+    lea                  r2, [strideq*3]       ; r2 is reused as stride*3 from here on
+    jmp                  wq
+.w4: ; 4 rows of 4 pixels per iteration
+    pshufb              xm0, xm4, [idxq]       ; 16 indices -> 16 palette bytes
+    add                idxq, 16
+    movd   [dstq+strideq*0], xm0
+    pextrd [dstq+strideq*1], xm0, 1
+    pextrd [dstq+strideq*2], xm0, 2
+    pextrd [dstq+r2       ], xm0, 3
+    lea                dstq, [dstq+strideq*4]
+    sub                  hd, 4                 ; no tail handling: h must be a multiple of 4
+    jg .w4
+    RET
+ALIGN function_align
+.w8: ; 4 rows of 8 pixels per iteration
+    pshufb              xm0, xm4, [idxq+16*0]
+    pshufb              xm1, xm4, [idxq+16*1]
+    add                idxq, 16*2
+    movq   [dstq+strideq*0], xm0
+    movhps [dstq+strideq*1], xm0
+    movq   [dstq+strideq*2], xm1
+    movhps [dstq+r2       ], xm1
+    lea                dstq, [dstq+strideq*4]
+    sub                  hd, 4
+    jg .w8
+    RET
+ALIGN function_align
+.w16: ; 4 rows of 16 pixels per iteration, one row per 128-bit lane
+    pshufb               m0, m4, [idxq+32*0]
+    pshufb               m1, m4, [idxq+32*1]
+    add                idxq, 32*2
+    mova         [dstq+strideq*0], xm0 ; mova: destination must be 16-byte aligned
+    vextracti128 [dstq+strideq*1], m0, 1
+    mova         [dstq+strideq*2], xm1
+    vextracti128 [dstq+r2       ], m1, 1
+    lea                dstq, [dstq+strideq*4]
+    sub                  hd, 4
+    jg .w16
+    RET
+ALIGN function_align
+.w32: ; 4 rows of 32 pixels per iteration, one ymm per row
+    pshufb               m0, m4, [idxq+32*0]
+    pshufb               m1, m4, [idxq+32*1]
+    pshufb               m2, m4, [idxq+32*2]
+    pshufb               m3, m4, [idxq+32*3]
+    add                idxq, 32*4
+    mova   [dstq+strideq*0], m0 ; mova: every row must be 32-byte aligned
+    mova   [dstq+strideq*1], m1
+    mova   [dstq+strideq*2], m2
+    mova   [dstq+r2       ], m3
+    lea                dstq, [dstq+strideq*4]
+    sub                  hd, 4
+    jg .w32
+    RET
+ALIGN function_align
+.w64: ; 2 rows of 64 pixels per iteration, two ymm per row
+    pshufb               m0, m4, [idxq+32*0]
+    pshufb               m1, m4, [idxq+32*1]
+    pshufb               m2, m4, [idxq+32*2]
+    pshufb               m3, m4, [idxq+32*3]
+    add                idxq, 32*4
+    mova [dstq+strideq*0+32*0], m0
+    mova [dstq+strideq*0+32*1], m1
+    mova [dstq+strideq*1+32*0], m2
+    mova [dstq+strideq*1+32*1], m3
+    lea                dstq, [dstq+strideq*2]
+    sub                  hd, 2                 ; h must be a multiple of 2
+    jg .w64
+    RET
 
 %endif
--- a/src/x86/ipred_init.c
+++ b/src/x86/ipred_init.c
@@ -44,6 +44,8 @@
 decl_cfl_pred_fn(dav1d_ipred_cfl_top_avx2);
 decl_cfl_pred_fn(dav1d_ipred_cfl_left_avx2);
 
+decl_pal_pred_fn(dav1d_pal_pred_avx2); /* cglobal pal_pred in x86/ipred.asm */
+
 void bitfn(dav1d_intra_pred_dsp_init_x86)(Dav1dIntraPredDSPContext *const c) {
     const unsigned flags = dav1d_get_cpu_flags();
 
@@ -65,5 +67,7 @@
     c->cfl_pred[DC_128_PRED]  = dav1d_ipred_cfl_128_avx2;
     c->cfl_pred[TOP_DC_PRED]  = dav1d_ipred_cfl_top_avx2;
     c->cfl_pred[LEFT_DC_PRED] = dav1d_ipred_cfl_left_avx2;
+
+    c->pal_pred = dav1d_pal_pred_avx2;
 #endif
 }
--- a/tests/checkasm/ipred.c
+++ b/tests/checkasm/ipred.c
@@ -142,6 +142,50 @@
     report("cfl_pred");
 }

+/* Check pal_pred against the C reference for every power-of-two block size
+ * from 4x4 to 64x64 with an aspect ratio of at most 4:1. */
+static void check_pal_pred(Dav1dIntraPredDSPContext *const c) {
+    ALIGN_STK_32(pixel, c_dst, 64 * 64,);
+    ALIGN_STK_32(pixel, a_dst, 64 * 64,);
+    ALIGN_STK_32(uint8_t, idx, 64 * 64,);
+    ALIGN_STK_16(uint16_t, pal, 8,);
+
+    declare_func(void, pixel *dst, ptrdiff_t stride, const uint16_t *pal,
+                 const uint8_t *idx, int w, int h);
+
+    for (int w = 4; w <= 64; w <<= 1)
+        if (check_func(c->pal_pred, "pal_pred_w%d_%dbpc", w, BITDEPTH))
+            for (int h = imax(w / 4, 4); h <= imin(w * 4, 64); h <<= 1)
+            {
+                /* Use the full 64-pixel buffer width as the stride, not w:
+                 * with stride == w a store past a row's right edge lands in
+                 * the next row (then gets overwritten) or, for the last row,
+                 * beyond the compared range. 64 pixels is a multiple of 32
+                 * bytes, so every row keeps the alignment mova stores need. */
+                const ptrdiff_t stride = 64 * sizeof(pixel);
+
+                /* 8 palette entries, each a valid pixel value. */
+                for (int i = 0; i < 8; i++)
+                    pal[i] = rand() & ((1 << BITDEPTH) - 1);
+
+                /* One index per pixel, rows packed w entries apart. */
+                for (int i = 0; i < w * h; i++)
+                    idx[i] = rand() & 7;
+
+                /* Identical fill in both buffers, then compare everything:
+                 * a mismatch inside or outside the w x h block fails. */
+                memset(c_dst, 0xAA, 64 * 64 * sizeof(*c_dst));
+                memset(a_dst, 0xAA, 64 * 64 * sizeof(*a_dst));
+                call_ref(c_dst, stride, pal, idx, w, h);
+                call_new(a_dst, stride, pal, idx, w, h);
+                if (memcmp(c_dst, a_dst, 64 * 64 * sizeof(*c_dst)))
+                    fail();
+
+                bench_new(a_dst, stride, pal, idx, w, h);
+            }
+    report("pal_pred");
+}
+
 void bitfn(checkasm_check_ipred)(void) {
     Dav1dIntraPredDSPContext c;
     bitfn(dav1d_intra_pred_dsp_init)(&c);
@@ -148,4 +192,5 @@

     check_intra_pred(&c);
     check_cfl_pred(&c);
+    check_pal_pred(&c);
 }