shithub: dav1d

Download patch

ref: 61dcd11ba7364a0c6a9cef49e7b99722feee133e
parent: e29fd5c0016fec27c88a36ac6f6eaaf416d91330
author: Henrik Gramner <[email protected]>
date: Sat Aug 10 10:34:59 EDT 2019

x86: Add an msac function for coefficient hi_tok decoding

This particular sequence is executed often enough to justify having
a separate slightly more optimized code path instead of just chaining
multiple generic symbol decoding function calls together.

--- a/src/msac.c
+++ b/src/msac.c
@@ -171,6 +171,22 @@
     return bit;
 }
 
+/* Decode a coefficient hi_tok: 3 plus up to four symbols (0..3) from cdf. */
+unsigned dav1d_msac_decode_hi_tok_c(MsacContext *const s, uint16_t *const cdf) {
+    /* The result starts at 3 and grows by every symbol decoded below. */
+    unsigned tok = 3;
+    /* At most four symbols are read, so the result lies in 3..15. */
+    for (int n_reads = 0; n_reads < 4; n_reads++) {
+        const unsigned sym =
+            dav1d_msac_decode_symbol_adapt4(s, cdf, 3);
+        tok += sym;
+        /* Any symbol other than 3 terminates the token. */
+        if (sym != 3)
+            break;
+    }
+    return tok;
+}
+
 void dav1d_msac_init(MsacContext *const s, const uint8_t *const data,
                      const size_t sz, const int disable_cdf_update_flag)
 {
--- a/src/msac.h
+++ b/src/msac.h
@@ -58,6 +58,7 @@
 unsigned dav1d_msac_decode_bool_adapt_c(MsacContext *s, uint16_t *cdf);
 unsigned dav1d_msac_decode_bool_equi_c(MsacContext *s);
 unsigned dav1d_msac_decode_bool_c(MsacContext *s, unsigned f);
+unsigned dav1d_msac_decode_hi_tok_c(MsacContext *s, uint16_t *cdf);
 int dav1d_msac_decode_subexp(MsacContext *s, int ref, int n, unsigned k);
 
 /* Supported n_symbols ranges: adapt4: 1-4, adapt8: 1-7, adapt16: 3-15 */
@@ -78,6 +79,9 @@
 #endif
 #ifndef dav1d_msac_decode_bool
 #define dav1d_msac_decode_bool           dav1d_msac_decode_bool_c
+#endif
+#ifndef dav1d_msac_decode_hi_tok
+#define dav1d_msac_decode_hi_tok         dav1d_msac_decode_hi_tok_c
 #endif
 
 static inline unsigned dav1d_msac_decode_bools(MsacContext *const s, unsigned n) {
--- a/src/recon_tmpl.c
+++ b/src/recon_tmpl.c
@@ -199,40 +199,13 @@
                 printf("Post-lo_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
                        t_dim->ctx, chroma, ctx, eob, rc, tok, ts->msac.rng);
 
-            // hi tok
             if (tok_br == 2) {
-#define dbg_print_hi_tok(i, tok, tok_br) \
-    if (dbg)\
-        printf("Post-hi_tok[%d][%d][%d][%d=%d=%d->%d]: r=%d\n",\
-               imin(t_dim->ctx, 3), chroma, br_ctx, i, rc, tok, tok_br,\
-               ts->msac.rng)
                 const int br_ctx = get_br_ctx(levels, 1, tx_class, x, y, stride);
-
-                tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                            br_cdf[br_ctx], 3);
-                tok = 3 + tok_br;
-                dbg_print_hi_tok(eob, tok + tok_br, tok_br);
-
-                if (tok_br == 3) {
-                    tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                             br_cdf[br_ctx], 3);
-                    tok = 6 + tok_br;
-                    dbg_print_hi_tok(eob, tok + tok_br, tok_br);
-                    if (tok_br == 3) {
-                        tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                                 br_cdf[br_ctx],
-                                                                 3);
-                        tok = 9 + tok_br;
-                        dbg_print_hi_tok(eob, tok + tok_br, tok_br);
-                        if (tok_br == 3) {
-                            tok = 12 +
-                                dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                                br_cdf[br_ctx],
-                                                                3);
-                            dbg_print_hi_tok(eob, tok + tok_br, tok_br);
-                        }
-                    }
-                }
+                tok = dav1d_msac_decode_hi_tok(&ts->msac, br_cdf[br_ctx]);
+                if (dbg)
+                    printf("Post-hi_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
+                           imin(t_dim->ctx, 3), chroma, br_ctx, eob, rc, tok,
+                           ts->msac.rng);
             }
 
             cf[rc] = tok;
@@ -249,37 +222,14 @@
                 printf("Post-lo_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
                        t_dim->ctx, chroma, ctx, i, rc, tok, ts->msac.rng);
 
-            // hi tok
             if (tok == 3) {
                 const int br_ctx = get_br_ctx(levels, 1, tx_class, x, y, stride);
-
-                int tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                             br_cdf[br_ctx], 3);
-                tok = 3 + tok_br;
-                dbg_print_hi_tok(i, tok + tok_br, tok_br);
-
-                if (tok_br == 3) {
-                    tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                             br_cdf[br_ctx], 3);
-
-                    tok = 6 + tok_br;
-                    dbg_print_hi_tok(i, tok + tok_br, tok_br);
-                    if (tok_br == 3) {
-                        tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                                 br_cdf[br_ctx],
-                                                                 3);
-                        tok = 9 + tok_br;
-                        dbg_print_hi_tok(i, tok + tok_br, tok_br);
-                        if (tok_br == 3) {
-                            tok = 12 + dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                                       br_cdf[br_ctx],
-                                                                       3);
-                            dbg_print_hi_tok(i, tok + tok_br, tok_br);
-                        }
-                    }
-                }
+                tok = dav1d_msac_decode_hi_tok(&ts->msac, br_cdf[br_ctx]);
+                if (dbg)
+                    printf("Post-hi_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
+                           imin(t_dim->ctx, 3), chroma, br_ctx, i, rc, tok,
+                           ts->msac.rng);
             }
-#undef dbg_print_hi_tok
             cf[rc] = tok;
             levels[x * stride + y] = (uint8_t) tok;
         }
@@ -292,43 +242,13 @@
                 printf("Post-dc_lo_tok[%d][%d][%d][%d]: r=%d\n",
                        t_dim->ctx, chroma, ctx, dc_tok, ts->msac.rng);
 
-            // hi tok
             if (dc_tok == 3) {
-#define dbg_print_hi_tok(dc_tok, tok_br) \
-    if (dbg) \
-        printf("Post-dc_hi_tok[%d][%d][%d][%d->%d]: r=%d\n", \
-               imin(t_dim->ctx, 3), chroma, br_ctx, tok_br, dc_tok, ts->msac.rng);
-
                 const int br_ctx = get_br_ctx(levels, 0, tx_class, 0, 0, stride);
-
-                int tok_br =
-                    dav1d_msac_decode_symbol_adapt4(&ts->msac, br_cdf[br_ctx], 3);
-                dc_tok = 3 + tok_br;
-
-                dbg_print_hi_tok(dc_tok + tok_br, tok_br);
-
-                if (tok_br == 3) {
-                    tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                             br_cdf[br_ctx], 3);
-                    dc_tok = 6 + tok_br;
-                    dbg_print_hi_tok(dc_tok + tok_br, tok_br);
-                    if (tok_br == 3) {
-                        tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                                 br_cdf[br_ctx],
-                                                                 3);
-                        dc_tok = 9 + tok_br;
-                        dbg_print_hi_tok(dc_tok + tok_br, tok_br);
-                        if (tok_br == 3) {
-                            dc_tok = 12 +
-                                dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                                br_cdf[br_ctx],
-                                                                3);
-                            dbg_print_hi_tok(dc_tok + tok_br, tok_br);
-                        }
-                    }
-                }
+                dc_tok = dav1d_msac_decode_hi_tok(&ts->msac, br_cdf[br_ctx]);
+                if (dbg) // log the br_ctx that selected the CDF, not a fixed 0
+                    printf("Post-dc_hi_tok[%d][%d][%d][%d]: r=%d\n",
+                           imin(t_dim->ctx, 3), chroma, br_ctx, dc_tok, ts->msac.rng);
             }
-#undef dbg_print_hi_tok
         }
     } else { // dc-only
         uint16_t *const lo_cdf = ts->cdf.coef.eob_base_tok[t_dim->ctx][chroma][0];
@@ -338,38 +258,13 @@
             printf("Post-dc_lo_tok[%d][%d][%d][%d]: r=%d\n",
                    t_dim->ctx, chroma, 0, dc_tok, ts->msac.rng);
 
-        // hi tok
         if (tok_br == 2) {
-#define dbg_print_hi_tok(dc_tok, tok_br) \
-    if (dbg) \
-        printf("Post-dc_hi_tok[%d][%d][0][%d->%d]: r=%d\n", \
-               imin(t_dim->ctx, 3), chroma, tok_br, dc_tok, ts->msac.rng);
-
-            tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, br_cdf[0], 3);
-            dc_tok = 3 + tok_br;
-
-            dbg_print_hi_tok(dc_tok + tok_br, tok_br);
-
-            if (tok_br == 3) {
-                tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, br_cdf[0], 3);
-                dc_tok = 6 + tok_br;
-                dbg_print_hi_tok(dc_tok + tok_br, tok_br);
-                if (tok_br == 3) {
-                    tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                             br_cdf[0], 3);
-                    dc_tok = 9 + tok_br;
-                    dbg_print_hi_tok(dc_tok + tok_br, tok_br);
-                    if (tok_br == 3) {
-                        dc_tok = 12 +
-                            dav1d_msac_decode_symbol_adapt4(&ts->msac,
-                                                            br_cdf[0], 3);
-                        dbg_print_hi_tok(dc_tok + tok_br, tok_br);
-                    }
-                }
-            }
+            dc_tok = dav1d_msac_decode_hi_tok(&ts->msac, br_cdf[0]);
+            if (dbg)
+                printf("Post-dc_hi_tok[%d][%d][0][%d]: r=%d\n",
+                       imin(t_dim->ctx, 3), chroma, dc_tok, ts->msac.rng);
         }
     }
-#undef dbg_print_hi_tok
 
     // residual and sign
     int dc_sign = 1 << 6;
--- a/src/x86/msac.asm
+++ b/src/x86/msac.asm
@@ -27,7 +27,7 @@
 
 SECTION_RODATA 64 ; avoids cacheline splits
 
-dw 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0
+min_prob:  dw 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0
 pw_0xff00: times 8 dw 0xff00
 pw_32:     times 8 dw 32
 
@@ -35,21 +35,24 @@
 %define resp   resq
 %define movp   movq
 %define c_shuf q3333
-%define DECODE_SYMBOL_ADAPT_INIT
+%macro DECODE_SYMBOL_ADAPT_INIT 0-1 ; empty here: accepts and ignores the optional hi_tok flag
+%endmacro
 %else
 %define resp   resd
 %define movp   movd
 %define c_shuf q1111
-%macro DECODE_SYMBOL_ADAPT_INIT 0
+%macro DECODE_SYMBOL_ADAPT_INIT 0-1 0 ; hi_tok (optional flag, defaults to 0)
     mov            t0, r0m
     mov            t1, r1m
+%if %1 == 0 ; the hi_tok caller takes only (s, cdf), so it skips the r2m load
     mov            t2, r2m
+%endif
 %if STACK_ALIGNMENT >= 16
-    sub           esp, 40
+    sub           esp, 40-%1*4 ; 40 bytes normally, 36 for hi_tok (undone by add esp, 36 in HI_TOK)
 %else
     mov           eax, esp
     and           esp, ~15
-    sub           esp, 40
+    sub           esp, 40-%1*4 ; same reservation on the manually aligned stack
     mov         [esp], eax
 %endif
 %endmacro
@@ -69,13 +72,13 @@
 SECTION .text
 
 %if WIN64
-DECLARE_REG_TMP 0, 1, 2, 3, 4, 5, 7, 3
-%define buf rsp+8  ; shadow space
+DECLARE_REG_TMP 0, 1, 2, 3, 4, 5, 7, 3, 8
+%define buf rsp+stack_offset+8 ; shadow space
 %elif UNIX64
-DECLARE_REG_TMP 0, 1, 2, 3, 4, 5, 7, 0
+DECLARE_REG_TMP 0, 1, 2, 3, 4, 5, 7, 0, 8
 %define buf rsp-40 ; red zone
 %else
-DECLARE_REG_TMP 2, 3, 4, 1, 5, 6, 5, 2
+DECLARE_REG_TMP 2, 3, 4, 1, 5, 6, 5, 2, 3
 %define buf esp+8
 %endif
 
@@ -440,3 +443,158 @@
     movzx         eax, al
 %endif
     jmp m(msac_decode_symbol_adapt4).renorm3
+
+%macro HI_TOK 1 ; update_cdf (nonzero: also adapt the CDF after every symbol)
+%if ARCH_X86_64 == 0
+    mov           eax, -24 ; x86-32: eax carries the loop counter at the top of each iteration
+%endif
+%%loop: ; one pass decodes one symbol from the same cdf
+%if %1
+    movzx         t2d, word [t1+3*2] ; t2d = cdf[3], the counter bumped and stored back below
+%endif
+    mova           m1, m0
+    pshuflw        m2, m2, q0000 ; broadcast rng to the four low words
+    psrlw          m1, 6
+    movd     [buf+12], m2 ; word [buf+14] = rng, read back below when the offset is 0
+    pand           m2, m4 ; rng & 0xff00
+    psllw          m1, 7
+    pmulhuw        m1, m2 ; m1 = (((cdf >> 6) << 7) * (rng & 0xff00)) >> 16, per word
+%if ARCH_X86_64 == 0
+    add           eax, 5 ; x86-32: counter += 5 ...
+    mov       [buf+8], eax ; ... spilled because eax is reused below
+%endif
+    pshuflw        m3, m3, c_shuf ; broadcast the top 16 bits of dif
+    paddw          m1, m5 ; add the min_prob entries held in m5
+    movq     [buf+16], m1 ; keep the four candidate values for the scalar lookup below
+    psubusw        m1, m3 ; saturates to 0 where candidate <= top bits of dif
+    pxor           m2, m2
+    pcmpeqw        m1, m2 ; m1 = 0xffff in exactly those words
+    pmovmskb      eax, m1 ; two mask bits per word
+%if %1
+    lea           ecx, [t2+80]
+    pcmpeqw        m2, m2 ; m2 = all ones
+    shr           ecx, 4 ; ecx = (counter + 80) >> 4, the shift used by psraw below
+    cmp           t2d, 32
+    adc           t2d, 0 ; counter += (counter < 32)
+    movd           m3, ecx
+    pavgw          m2, m1 ; m2 = 0xffff where the mask is set, 0x8000 elsewhere
+    psubw          m2, m0 ; m2 -= cdf
+    psubw          m0, m1 ; m0 = cdf - mask, i.e. cdf + 1 where the mask is set
+    psraw          m2, m3
+    paddw          m0, m2 ; m0 = updated cdf words
+    movq         [t1], m0 ; stores all four words; word 3 is rewritten next
+    mov      [t1+3*2], t2w ; cdf[3] = updated counter
+%endif
+    tzcnt         eax, eax ; eax = byte offset of the first masked word = 2 * tok_br
+    movzx         ecx, word [buf+rax+16] ; ecx = selected candidate
+    movzx         t2d, word [buf+rax+14] ; t2d = preceding candidate (rng itself for offset 0)
+    not            t4
+%if ARCH_X86_64
+    add           t6d, 5
+%endif
+    sub           eax, 5   ; setup for merging the tok_br and tok branches
+    sub           t2d, ecx ; new rng = preceding candidate - selected candidate
+    shl           rcx, gprsize*8-16 ; move the selected candidate to the top 16 bits
+    add            t4, rcx ; on the inverted dif this subtracts it from dif
+    bsr           ecx, t2d
+    xor           ecx, 15 ; ecx = 15 - index of the highest set bit of rng
+    shl           t2d, cl ; renormalize rng ...
+    shl            t4, cl ; ... and the (still inverted) dif by the same amount
+    movd           m2, t2d ; m2 = rng for the next pass
+    mov [t7+msac.rng], t2d
+    not            t4 ; back to dif; the bits shifted in become ones
+    sub           t5d, ecx ; cnt -= shift
+    jge %%end ; no refill needed while cnt >= 0
+    mov            t2, [t7+msac.buf] ; t2 = current buffer position
+    mov           rcx, [t7+msac.end] ; rcx = end of buffer
+%if UNIX64 == 0
+    push           t8 ; t8 is borrowed as scratch; on x86-32 it aliases t1 (see DECLARE_REG_TMP)
+%endif
+    lea            t8, [t2+gprsize]
+    cmp            t8, rcx
+    ja %%refill_eob ; fewer than gprsize bytes left: refill bytewise
+    mov            t2, [t2] ; load gprsize bytes at once
+    lea           ecx, [t5+23]
+    add           t5d, 16
+    shr           ecx, 3
+    bswap          t2 ; first buffer byte becomes the most significant one
+    sub            t8, rcx ; t8 = new buffer position
+    shl           ecx, 3
+    shr            t2, cl
+    sub           ecx, t5d
+    mov           t5d, gprsize*8-16
+    shl            t2, cl
+    mov [t7+msac.buf], t8
+%if UNIX64 == 0
+    pop            t8
+%endif
+    sub           t5d, ecx ; new cnt
+    xor            t4, t2 ; merge the new bits into dif
+%%end:
+    movp           m3, t4 ; m3 = dif for the next pass
+%if ARCH_X86_64
+    add           t6d, eax ; CF = tok_br < 3 || tok == 15
+    jnc %%loop
+    lea           eax, [t6+30] ; counter is -24 + 2 * (sum of tok_br); halved below to give tok
+%else
+    add           eax, [buf+8] ; x86-32: same exit test using the spilled counter
+    jnc %%loop
+    add           eax, 30
+%if STACK_ALIGNMENT >= 16
+    add           esp, 36 ; release the 36 bytes reserved by DECODE_SYMBOL_ADAPT_INIT 1
+%else
+    mov           esp, [esp] ; restore the esp saved at [esp] by DECODE_SYMBOL_ADAPT_INIT
+%endif
+%endif
+    mov [t7+msac.dif], t4
+    shr           eax, 1 ; eax = tok = 3 + sum of tok_br (return value)
+    mov [t7+msac.cnt], t5d
+    RET
+%%refill_eob: ; near the end of the buffer: refill one byte at a time
+    mov            t8, rcx ; t8 = end of buffer
+    mov           ecx, gprsize*8-24
+    sub           ecx, t5d ; ecx = shift for the first byte
+%%refill_eob_loop:
+    cmp            t2, t8
+    jae %%refill_eob_end ; stop at the end of the buffer
+    movzx         t5d, byte [t2]
+    inc            t2
+    shl            t5, cl
+    xor            t4, t5 ; merge this byte into dif
+    sub           ecx, 8
+    jge %%refill_eob_loop ; continue while the shift stays non-negative
+%%refill_eob_end:
+%if UNIX64 == 0
+    pop            t8 ; undo the push made before branching here
+%endif
+    mov           t5d, gprsize*8-24
+    mov [t7+msac.buf], t2 ; store the advanced buffer position
+    sub           t5d, ecx ; new cnt
+    jmp %%end
+%endmacro
+
+cglobal msac_decode_hi_tok, 0, 7 + ARCH_X86_64, 6 ; unsigned (MsacContext *s, uint16_t *cdf), see x86/msac.h
+    DECODE_SYMBOL_ADAPT_INIT 1 ; hi_tok variant of the shared prologue macro
+%if ARCH_X86_64 == 0 && PIC
+    LEA            t2, min_prob+12*2 ; t2 is the base for the constant loads below
+    %define base t2-(min_prob+12*2)
+%else
+    %define base 0
+%endif
+    movq           m0, [t1] ; m0 = the four 16-bit words at cdf
+    movd           m2, [t0+msac.rng] ; m2 = rng
+    mov           eax, [t0+msac.update_cdf] ; selects which HI_TOK expansion runs
+    movq           m4, [base+pw_0xff00]
+    movp           m3, [t0+msac.dif] ; m3 = dif (vector copy)
+    movq           m5, [base+min_prob+12*2] ; m5 = last four min_prob entries: 12, 8, 4, 0
+    mov            t4, [t0+msac.dif] ; t4 = dif (scalar copy)
+    mov           t5d, [t0+msac.cnt] ; t5d = cnt
+%if ARCH_X86_64
+    mov           t6d, -24 ; loop counter start; x86-32 sets it up in eax inside HI_TOK
+%endif
+    movifnidn      t7, t0 ; t7 = s, used for the stores in HI_TOK
+    test          eax, eax
+    jz .no_update_cdf
+    HI_TOK          1 ; adapting version; returns via its own RET
+.no_update_cdf:
+    HI_TOK          0 ; same decode without the CDF update
--- a/src/x86/msac.h
+++ b/src/x86/msac.h
@@ -37,11 +37,13 @@
 unsigned dav1d_msac_decode_bool_adapt_sse2(MsacContext *s, uint16_t *cdf);
 unsigned dav1d_msac_decode_bool_equi_sse2(MsacContext *s);
 unsigned dav1d_msac_decode_bool_sse2(MsacContext *s, unsigned f);
+unsigned dav1d_msac_decode_hi_tok_sse2(MsacContext *s, uint16_t *cdf);
 
 #if ARCH_X86_64 || defined(__SSE2__) || (defined(_M_IX86_FP) && _M_IX86_FP >= 2)
 #define dav1d_msac_decode_symbol_adapt4  dav1d_msac_decode_symbol_adapt4_sse2
 #define dav1d_msac_decode_symbol_adapt8  dav1d_msac_decode_symbol_adapt8_sse2
 #define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_sse2
+#define dav1d_msac_decode_hi_tok         dav1d_msac_decode_hi_tok_sse2
 #endif
 
 #define dav1d_msac_decode_bool_adapt     dav1d_msac_decode_bool_adapt_sse2
--- a/tests/checkasm/msac.c
+++ b/tests/checkasm/msac.c
@@ -38,7 +38,7 @@
 /* The normal code doesn't use function pointers */
 typedef unsigned (*decode_symbol_adapt_fn)(MsacContext *s, uint16_t *cdf,
                                            size_t n_symbols);
-typedef unsigned (*decode_bool_adapt_fn)(MsacContext *s, uint16_t *cdf);
+typedef unsigned (*decode_adapt_fn)(MsacContext *s, uint16_t *cdf);
 typedef unsigned (*decode_bool_equi_fn)(MsacContext *s);
 typedef unsigned (*decode_bool_fn)(MsacContext *s, unsigned f);
 
@@ -46,9 +46,10 @@
     decode_symbol_adapt_fn symbol_adapt4;
     decode_symbol_adapt_fn symbol_adapt8;
     decode_symbol_adapt_fn symbol_adapt16;
-    decode_bool_adapt_fn   bool_adapt;
+    decode_adapt_fn        bool_adapt;
     decode_bool_equi_fn    bool_equi;
     decode_bool_fn         bool;
+    decode_adapt_fn        hi_tok;
 } MsacDSPContext;
 
 static void randomize_cdf(uint16_t *const cdf, const int n) {
@@ -199,6 +200,35 @@
     report("decode_bool");
 }
 
+static void check_decode_hi_tok(MsacDSPContext *const c, uint8_t *const buf) {
+    ALIGN_STK_16(uint16_t, cdf, 2, [16]); // cdf[0] is passed to call_ref, cdf[1] to call_new
+    MsacContext s_c, s_a; // decoder state passed to call_ref / call_new respectively
+    /* Run c->hi_tok against the reference in lockstep; compare result, state and CDF. */
+    if (check_func(c->hi_tok, "msac_decode_hi_tok")) {
+        declare_func(unsigned, MsacContext *s, uint16_t *cdf);
+        for (int cdf_update = 0; cdf_update <= 1; cdf_update++) { // first with CDF updates disabled, then enabled
+            dav1d_msac_init(&s_c, buf, BUF_SIZE, !cdf_update); // last argument is disable_cdf_update_flag
+            s_a = s_c; // both decoders start from the same state
+            randomize_cdf(cdf[0], 3);
+            memcpy(cdf[1], cdf[0], sizeof(*cdf)); // and from the same CDF row
+            for (int i = 0; i < 64; i++) { // 64 consecutive hi_tok decodes per pass
+                unsigned c_res = call_ref(&s_c, cdf[0]);
+                unsigned a_res = call_new(&s_a, cdf[1]);
+                if (c_res != a_res || msac_cmp(&s_c, &s_a) || // return value or decoder state differs
+                    memcmp(cdf[0], cdf[1], sizeof(*cdf))) // or the CDF rows diverged
+                {
+                    if (fail())
+                        msac_dump(c_res, a_res, &s_c, &s_a, cdf[0], cdf[1], 3);
+                    break; // stop this pass at the first mismatch
+                }
+            }
+            if (cdf_update) // benchmark only in the pass with CDF updates enabled
+                bench_new(&s_a, cdf[1]);
+        }
+    }
+    report("decode_hi_tok");
+}
+
 void checkasm_check_msac(void) {
     MsacDSPContext c;
     c.symbol_adapt4  = dav1d_msac_decode_symbol_adapt_c;
@@ -207,6 +237,7 @@
     c.bool_adapt     = dav1d_msac_decode_bool_adapt_c;
     c.bool_equi      = dav1d_msac_decode_bool_equi_c;
     c.bool           = dav1d_msac_decode_bool_c;
+    c.hi_tok         = dav1d_msac_decode_hi_tok_c;
 
 #if ARCH_AARCH64 && HAVE_ASM
     if (dav1d_get_cpu_flags() & DAV1D_ARM_CPU_FLAG_NEON) {
@@ -225,6 +256,7 @@
         c.bool_adapt     = dav1d_msac_decode_bool_adapt_sse2;
         c.bool_equi      = dav1d_msac_decode_bool_equi_sse2;
         c.bool           = dav1d_msac_decode_bool_sse2;
+        c.hi_tok         = dav1d_msac_decode_hi_tok_sse2;
     }
 #endif
 
@@ -234,4 +266,5 @@
 
     check_decode_symbol(&c, buf);
     check_decode_bool(&c, buf);
+    check_decode_hi_tok(&c, buf);
 }