ref: 61dcd11ba7364a0c6a9cef49e7b99722feee133e
parent: e29fd5c0016fec27c88a36ac6f6eaaf416d91330
author: Henrik Gramner <[email protected]>
date: Sat Aug 10 10:34:59 EDT 2019
x86: Add an msac function for coefficient hi_tok decoding. This particular sequence is executed often enough to justify having a separate, slightly more optimized code path instead of just chaining multiple generic symbol decoding function calls together.
--- a/src/msac.c
+++ b/src/msac.c
@@ -171,6 +171,22 @@
return bit;
}
+/*
+ * Decode a coefficient hi_tok (requested by the coefficient decoder once the
+ * low token has saturated at 3).
+ *
+ * Up to four 4-ary symbols (0..3) are read with the same adaptive cdf. Each
+ * symbol is added to a running token that starts at 3, and another symbol is
+ * read only while the previous one took its maximum value 3, so the result
+ * lies in 3..15.
+ */
+unsigned dav1d_msac_decode_hi_tok_c(MsacContext *const s, uint16_t *const cdf) {
+    unsigned tok = 3;
+    for (int step = 0; step < 4; step++) {
+        const unsigned sym = dav1d_msac_decode_symbol_adapt4(s, cdf, 3);
+        tok += sym;
+        if (sym != 3)
+            break;
+    }
+    return tok;
+}
+
void dav1d_msac_init(MsacContext *const s, const uint8_t *const data,
const size_t sz, const int disable_cdf_update_flag)
{
--- a/src/msac.h
+++ b/src/msac.h
@@ -58,6 +58,7 @@
unsigned dav1d_msac_decode_bool_adapt_c(MsacContext *s, uint16_t *cdf);
unsigned dav1d_msac_decode_bool_equi_c(MsacContext *s);
unsigned dav1d_msac_decode_bool_c(MsacContext *s, unsigned f);
+/* Coefficient hi_tok decoding; returns a token in 3..15 (C reference in msac.c). */
+unsigned dav1d_msac_decode_hi_tok_c(MsacContext *s, uint16_t *cdf);
int dav1d_msac_decode_subexp(MsacContext *s, int ref, int n, unsigned k);
/* Supported n_symbols ranges: adapt4: 1-4, adapt8: 1-7, adapt16: 3-15 */
@@ -78,6 +79,9 @@
#endif
#ifndef dav1d_msac_decode_bool
#define dav1d_msac_decode_bool dav1d_msac_decode_bool_c
+#endif
+/* Fall back to the C version unless an arch header already mapped it. */
+#ifndef dav1d_msac_decode_hi_tok
+#define dav1d_msac_decode_hi_tok dav1d_msac_decode_hi_tok_c
#endif
static inline unsigned dav1d_msac_decode_bools(MsacContext *const s, unsigned n) {
--- a/src/recon_tmpl.c
+++ b/src/recon_tmpl.c
@@ -199,40 +199,13 @@
printf("Post-lo_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
t_dim->ctx, chroma, ctx, eob, rc, tok, ts->msac.rng);
- // hi tok
if (tok_br == 2) {
-#define dbg_print_hi_tok(i, tok, tok_br) \
- if (dbg)\
- printf("Post-hi_tok[%d][%d][%d][%d=%d=%d->%d]: r=%d\n",\
- imin(t_dim->ctx, 3), chroma, br_ctx, i, rc, tok, tok_br,\
- ts->msac.rng)
const int br_ctx = get_br_ctx(levels, 1, tx_class, x, y, stride);
-
- tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
- br_cdf[br_ctx], 3);
- tok = 3 + tok_br;
- dbg_print_hi_tok(eob, tok + tok_br, tok_br);
-
- if (tok_br == 3) {
- tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
- br_cdf[br_ctx], 3);
- tok = 6 + tok_br;
- dbg_print_hi_tok(eob, tok + tok_br, tok_br);
- if (tok_br == 3) {
- tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
- br_cdf[br_ctx],
- 3);
- tok = 9 + tok_br;
- dbg_print_hi_tok(eob, tok + tok_br, tok_br);
- if (tok_br == 3) {
- tok = 12 +
- dav1d_msac_decode_symbol_adapt4(&ts->msac,
- br_cdf[br_ctx],
- 3);
- dbg_print_hi_tok(eob, tok + tok_br, tok_br);
- }
- }
- }
+ tok = dav1d_msac_decode_hi_tok(&ts->msac, br_cdf[br_ctx]);
+ if (dbg)
+ printf("Post-hi_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
+ imin(t_dim->ctx, 3), chroma, br_ctx, eob, rc, tok,
+ ts->msac.rng);
}
cf[rc] = tok;
@@ -249,37 +222,14 @@
printf("Post-lo_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
t_dim->ctx, chroma, ctx, i, rc, tok, ts->msac.rng);
- // hi tok
if (tok == 3) {
const int br_ctx = get_br_ctx(levels, 1, tx_class, x, y, stride);
-
- int tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
- br_cdf[br_ctx], 3);
- tok = 3 + tok_br;
- dbg_print_hi_tok(i, tok + tok_br, tok_br);
-
- if (tok_br == 3) {
- tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
- br_cdf[br_ctx], 3);
-
- tok = 6 + tok_br;
- dbg_print_hi_tok(i, tok + tok_br, tok_br);
- if (tok_br == 3) {
- tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
- br_cdf[br_ctx],
- 3);
- tok = 9 + tok_br;
- dbg_print_hi_tok(i, tok + tok_br, tok_br);
- if (tok_br == 3) {
- tok = 12 + dav1d_msac_decode_symbol_adapt4(&ts->msac,
- br_cdf[br_ctx],
- 3);
- dbg_print_hi_tok(i, tok + tok_br, tok_br);
- }
- }
- }
+ tok = dav1d_msac_decode_hi_tok(&ts->msac, br_cdf[br_ctx]);
+ if (dbg)
+ printf("Post-hi_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
+ imin(t_dim->ctx, 3), chroma, br_ctx, i, rc, tok,
+ ts->msac.rng);
}
-#undef dbg_print_hi_tok
cf[rc] = tok;
levels[x * stride + y] = (uint8_t) tok;
}
@@ -292,43 +242,13 @@
printf("Post-dc_lo_tok[%d][%d][%d][%d]: r=%d\n",
t_dim->ctx, chroma, ctx, dc_tok, ts->msac.rng);
- // hi tok
if (dc_tok == 3) {
-#define dbg_print_hi_tok(dc_tok, tok_br) \
- if (dbg) \
- printf("Post-dc_hi_tok[%d][%d][%d][%d->%d]: r=%d\n", \
- imin(t_dim->ctx, 3), chroma, br_ctx, tok_br, dc_tok, ts->msac.rng);
-
const int br_ctx = get_br_ctx(levels, 0, tx_class, 0, 0, stride);
-
- int tok_br =
- dav1d_msac_decode_symbol_adapt4(&ts->msac, br_cdf[br_ctx], 3);
- dc_tok = 3 + tok_br;
-
- dbg_print_hi_tok(dc_tok + tok_br, tok_br);
-
- if (tok_br == 3) {
- tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
- br_cdf[br_ctx], 3);
- dc_tok = 6 + tok_br;
- dbg_print_hi_tok(dc_tok + tok_br, tok_br);
- if (tok_br == 3) {
- tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
- br_cdf[br_ctx],
- 3);
- dc_tok = 9 + tok_br;
- dbg_print_hi_tok(dc_tok + tok_br, tok_br);
- if (tok_br == 3) {
- dc_tok = 12 +
- dav1d_msac_decode_symbol_adapt4(&ts->msac,
- br_cdf[br_ctx],
- 3);
- dbg_print_hi_tok(dc_tok + tok_br, tok_br);
- }
- }
- }
+ dc_tok = dav1d_msac_decode_hi_tok(&ts->msac, br_cdf[br_ctx]);
+ if (dbg)
+ printf("Post-dc_hi_tok[%d][%d][0][%d]: r=%d\n",
+ imin(t_dim->ctx, 3), chroma, dc_tok, ts->msac.rng);
}
-#undef dbg_print_hi_tok
}
} else { // dc-only
uint16_t *const lo_cdf = ts->cdf.coef.eob_base_tok[t_dim->ctx][chroma][0];
@@ -338,38 +258,13 @@
printf("Post-dc_lo_tok[%d][%d][%d][%d]: r=%d\n",
t_dim->ctx, chroma, 0, dc_tok, ts->msac.rng);
- // hi tok
if (tok_br == 2) {
-#define dbg_print_hi_tok(dc_tok, tok_br) \
- if (dbg) \
- printf("Post-dc_hi_tok[%d][%d][0][%d->%d]: r=%d\n", \
- imin(t_dim->ctx, 3), chroma, tok_br, dc_tok, ts->msac.rng);
-
- tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, br_cdf[0], 3);
- dc_tok = 3 + tok_br;
-
- dbg_print_hi_tok(dc_tok + tok_br, tok_br);
-
- if (tok_br == 3) {
- tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, br_cdf[0], 3);
- dc_tok = 6 + tok_br;
- dbg_print_hi_tok(dc_tok + tok_br, tok_br);
- if (tok_br == 3) {
- tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
- br_cdf[0], 3);
- dc_tok = 9 + tok_br;
- dbg_print_hi_tok(dc_tok + tok_br, tok_br);
- if (tok_br == 3) {
- dc_tok = 12 +
- dav1d_msac_decode_symbol_adapt4(&ts->msac,
- br_cdf[0], 3);
- dbg_print_hi_tok(dc_tok + tok_br, tok_br);
- }
- }
- }
+ dc_tok = dav1d_msac_decode_hi_tok(&ts->msac, br_cdf[0]);
+ if (dbg)
+ printf("Post-dc_hi_tok[%d][%d][0][%d]: r=%d\n",
+ imin(t_dim->ctx, 3), chroma, dc_tok, ts->msac.rng);
}
}
-#undef dbg_print_hi_tok
// residual and sign
int dc_sign = 1 << 6;
--- a/src/x86/msac.asm
+++ b/src/x86/msac.asm
@@ -27,7 +27,7 @@
SECTION_RODATA 64 ; avoids cacheline splits
-dw 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0
+; min_prob[i] = 4 * (15 - i). msac_decode_hi_tok loads the last four entries
+; (12, 8, 4, 0) into m5 and adds them to the scaled cdf values in HI_TOK.
+min_prob: dw 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0
pw_0xff00: times 8 dw 0xff00
pw_32: times 8 dw 32
@@ -35,21 +35,24 @@
%define resp resq
%define movp movq
%define c_shuf q3333
-%define DECODE_SYMBOL_ADAPT_INIT
+; x86-64: t0-t2 are r0-r2 (see DECLARE_REG_TMP), so no argument setup is
+; needed; the optional hi_tok argument is accepted and ignored.
+%macro DECODE_SYMBOL_ADAPT_INIT 0-1
+%endmacro
%else
%define resp resd
%define movp movd
%define c_shuf q1111
-%macro DECODE_SYMBOL_ADAPT_INIT 0
+%macro DECODE_SYMBOL_ADAPT_INIT 0-1 0 ; hi_tok
+; x86-32: load the stack arguments into t0-t2 and reserve a scratch area on
+; the stack (used as buf). With hi_tok=1 the third argument (r2m) is not
+; loaded and 4 fewer bytes are reserved; msac_decode_hi_tok releases 36 bytes
+; to match when STACK_ALIGNMENT >= 16, otherwise it restores the esp saved
+; at [esp] below.
mov t0, r0m
mov t1, r1m
+%if %1 == 0
mov t2, r2m
+%endif
%if STACK_ALIGNMENT >= 16
- sub esp, 40
+ sub esp, 40-%1*4
%else
mov eax, esp
and esp, ~15
- sub esp, 40
+ sub esp, 40-%1*4
mov [esp], eax
%endif
%endmacro
@@ -69,13 +72,13 @@
SECTION .text
+; Temporaries t0-t8. t8 is only used by msac_decode_hi_tok's refill path; on
+; x86-32 it aliases t1 (r3), which HI_TOK saves and restores around that use.
%if WIN64
-DECLARE_REG_TMP 0, 1, 2, 3, 4, 5, 7, 3
-%define buf rsp+8 ; shadow space
+DECLARE_REG_TMP 0, 1, 2, 3, 4, 5, 7, 3, 8
+; NOTE(review): stack_offset presumably compensates for registers pushed by
+; the cglobal prologue (msac_decode_hi_tok declares 8 GPRs) - verify.
+%define buf rsp+stack_offset+8 ; shadow space
%elif UNIX64
-DECLARE_REG_TMP 0, 1, 2, 3, 4, 5, 7, 0
+DECLARE_REG_TMP 0, 1, 2, 3, 4, 5, 7, 0, 8
%define buf rsp-40 ; red zone
%else
-DECLARE_REG_TMP 2, 3, 4, 1, 5, 6, 5, 2
+DECLARE_REG_TMP 2, 3, 4, 1, 5, 6, 5, 2, 3
%define buf esp+8
%endif
@@ -440,3 +443,158 @@
movzx eax, al
%endif
jmp m(msac_decode_symbol_adapt4).renorm3
+
+; HI_TOK update_cdf
+; Decodes a coefficient hi_tok: a 4-ary symbol (cdf[0..2], with the
+; adaptation count in cdf[3]) is decoded repeatedly with the same cdf until a
+; symbol below 3 is decoded or the token reaches 15 (see the CF comment at the
+; loop branch). With update_cdf set, the cdf and its count are written back
+; after every symbol. Expects msac_decode_hi_tok's register setup:
+;   m0 = cdf, m2 = rng, m3 = dif, m4 = pw_0xff00, m5 = min_prob[12..15],
+;   t1 = cdf pointer, t4 = dif, t5d = cnt, t7 = MsacContext pointer,
+;   t6d = loop counter on x86-64 (eax / [buf+8] on x86-32).
+; Stores rng, dif, cnt (and buf when refilling) to the context and returns
+; the token in eax.
+%macro HI_TOK 1 ; update_cdf
+%if ARCH_X86_64 == 0
+ mov eax, -24
+%endif
+%%loop:
+%if %1
+ movzx t2d, word [t1+3*2]
+%endif
+ ; v[i] = (((cdf[i] >> 6) << 7) * (rng & 0xff00) >> 16) + min_prob[12+i];
+ ; the unmasked rng is stored at [buf+12] as the upper bound for symbol 0
+ mova m1, m0
+ pshuflw m2, m2, q0000
+ psrlw m1, 6
+ movd [buf+12], m2
+ pand m2, m4
+ psllw m1, 7
+ pmulhuw m1, m2
+%if ARCH_X86_64 == 0
+ add eax, 5
+ mov [buf+8], eax
+%endif
+ pshuflw m3, m3, c_shuf
+ paddw m1, m5
+ movq [buf+16], m1
+ ; m1 = lane mask of v[i] <= the top 16 bits of dif
+ psubusw m1, m3
+ pxor m2, m2
+ pcmpeqw m1, m2
+ pmovmskb eax, m1
+%if %1
+ ; adapt the cdf with rate = (count + 80) >> 4, increment count while it is
+ ; below 32, then write cdf and count back
+ lea ecx, [t2+80]
+ pcmpeqw m2, m2
+ shr ecx, 4
+ cmp t2d, 32
+ adc t2d, 0
+ movd m3, ecx
+ pavgw m2, m1
+ psubw m2, m0
+ psubw m0, m1
+ psraw m2, m3
+ paddw m0, m2
+ movq [t1], m0
+ mov [t1+3*2], t2w
+%endif
+ ; eax = 2 * index of the first matching lane (the decoded symbol)
+ tzcnt eax, eax
+ ; ecx = v[sym], t2d = previous bound (v[sym-1], or rng for symbol 0)
+ movzx ecx, word [buf+rax+16]
+ movzx t2d, word [buf+rax+14]
+ ; update dif; new rng = t2d - ecx, shifted left until bit 15 is set,
+ ; with dif shifted by the same amount and cnt reduced by it
+ not t4
+%if ARCH_X86_64
+ add t6d, 5
+%endif
+ sub eax, 5 ; setup for merging the tok_br and tok branches
+ sub t2d, ecx
+ shl rcx, gprsize*8-16
+ add t4, rcx
+ bsr ecx, t2d
+ xor ecx, 15
+ shl t2d, cl
+ shl t4, cl
+ movd m2, t2d
+ mov [t7+msac.rng], t2d
+ not t4
+ sub t5d, ecx
+ jge %%end
+ ; cnt went negative: refill dif from msac.buf
+ mov t2, [t7+msac.buf]
+ mov rcx, [t7+msac.end]
+ ; t8 is outside cglobal's declared registers (and aliases t1 on x86-32),
+ ; so preserve it except on UNIX64
+%if UNIX64 == 0
+ push t8
+%endif
+ ; fast path reads one gprsize-byte big-endian word; near msac.end the
+ ; byte-wise %%refill_eob path is used instead
+ lea t8, [t2+gprsize]
+ cmp t8, rcx
+ ja %%refill_eob
+ mov t2, [t2]
+ lea ecx, [t5+23]
+ add t5d, 16
+ shr ecx, 3
+ bswap t2
+ sub t8, rcx
+ shl ecx, 3
+ shr t2, cl
+ sub ecx, t5d
+ mov t5d, gprsize*8-16
+ shl t2, cl
+ mov [t7+msac.buf], t8
+%if UNIX64 == 0
+ pop t8
+%endif
+ sub t5d, ecx
+ xor t4, t2
+%%end:
+ movp m3, t4
+%if ARCH_X86_64
+ add t6d, eax ; CF = tok_br < 3 || tok == 15
+ jnc %%loop
+ lea eax, [t6+30]
+%else
+ add eax, [buf+8]
+ jnc %%loop
+ add eax, 30
+%if STACK_ALIGNMENT >= 16
+ add esp, 36
+%else
+ mov esp, [esp]
+%endif
+%endif
+ ; eax = loop counter + 30; store dif/cnt and return eax >> 1
+ mov [t7+msac.dif], t4
+ shr eax, 1
+ mov [t7+msac.cnt], t5d
+ RET
+; byte-wise refill: stops at msac.end or once enough bits have been loaded
+%%refill_eob:
+ mov t8, rcx
+ mov ecx, gprsize*8-24
+ sub ecx, t5d
+%%refill_eob_loop:
+ cmp t2, t8
+ jae %%refill_eob_end
+ movzx t5d, byte [t2]
+ inc t2
+ shl t5, cl
+ xor t4, t5
+ sub ecx, 8
+ jge %%refill_eob_loop
+%%refill_eob_end:
+%if UNIX64 == 0
+ pop t8
+%endif
+ mov t5d, gprsize*8-24
+ mov [t7+msac.buf], t2
+ sub t5d, ecx
+ jmp %%end
+%endmacro
+
+; unsigned dav1d_msac_decode_hi_tok_sse2(MsacContext *s, uint16_t *cdf)
+; Loads the decoder state into registers, then runs the HI_TOK loop with or
+; without cdf adaptation depending on msac.update_cdf. Each HI_TOK expansion
+; ends in RET, so the update path never falls through into .no_update_cdf.
+cglobal msac_decode_hi_tok, 0, 7 + ARCH_X86_64, 6
+ DECODE_SYMBOL_ADAPT_INIT 1
+ ; x86-32 PIC: address constants relative to t2, which is free because
+ ; DECODE_SYMBOL_ADAPT_INIT 1 does not load a third argument into it
+%if ARCH_X86_64 == 0 && PIC
+ LEA t2, min_prob+12*2
+ %define base t2-(min_prob+12*2)
+%else
+ %define base 0
+%endif
+ ; m0 = cdf[0..3] (cdf[3] is the adaptation count), m2 = rng, m3 = dif,
+ ; m4 = 0xff00 mask, m5 = min_prob[12..15]
+ movq m0, [t1]
+ movd m2, [t0+msac.rng]
+ mov eax, [t0+msac.update_cdf]
+ movq m4, [base+pw_0xff00]
+ movp m3, [t0+msac.dif]
+ movq m5, [base+min_prob+12*2]
+ mov t4, [t0+msac.dif]
+ mov t5d, [t0+msac.cnt]
+%if ARCH_X86_64
+ ; loop counter; HI_TOK returns (t6d + 30) >> 1
+ mov t6d, -24
+%endif
+ ; t7 = context pointer used by HI_TOK (no move where t7 is t0)
+ movifnidn t7, t0
+ test eax, eax
+ jz .no_update_cdf
+ HI_TOK 1
+.no_update_cdf:
+ HI_TOK 0
--- a/src/x86/msac.h
+++ b/src/x86/msac.h
@@ -37,11 +37,13 @@
unsigned dav1d_msac_decode_bool_adapt_sse2(MsacContext *s, uint16_t *cdf);
unsigned dav1d_msac_decode_bool_equi_sse2(MsacContext *s);
unsigned dav1d_msac_decode_bool_sse2(MsacContext *s, unsigned f);
+unsigned dav1d_msac_decode_hi_tok_sse2(MsacContext *s, uint16_t *cdf);
#if ARCH_X86_64 || defined(__SSE2__) || (defined(_M_IX86_FP) && _M_IX86_FP >= 2)
#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt4_sse2
#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt8_sse2
#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_sse2
+/* Like the symbol_adapt functions, hi_tok maps to the SSE2 asm only when
+ * SSE2 is guaranteed at build time. */
+#define dav1d_msac_decode_hi_tok dav1d_msac_decode_hi_tok_sse2
#endif
#define dav1d_msac_decode_bool_adapt dav1d_msac_decode_bool_adapt_sse2
--- a/tests/checkasm/msac.c
+++ b/tests/checkasm/msac.c
@@ -38,7 +38,7 @@
/* The normal code doesn't use function pointers */
typedef unsigned (*decode_symbol_adapt_fn)(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
-typedef unsigned (*decode_bool_adapt_fn)(MsacContext *s, uint16_t *cdf);
+/* Shared signature of bool_adapt and hi_tok: (decoder state, adaptive cdf). */
+typedef unsigned (*decode_adapt_fn)(MsacContext *s, uint16_t *cdf);
typedef unsigned (*decode_bool_equi_fn)(MsacContext *s);
typedef unsigned (*decode_bool_fn)(MsacContext *s, unsigned f);
@@ -46,9 +46,10 @@
decode_symbol_adapt_fn symbol_adapt4;
decode_symbol_adapt_fn symbol_adapt8;
decode_symbol_adapt_fn symbol_adapt16;
- decode_bool_adapt_fn bool_adapt;
+ decode_adapt_fn bool_adapt;
decode_bool_equi_fn bool_equi;
decode_bool_fn bool;
+ decode_adapt_fn hi_tok;
} MsacDSPContext;
static void randomize_cdf(uint16_t *const cdf, const int n) {
@@ -199,6 +200,35 @@
report("decode_bool");
}
+/*
+ * checkasm test for msac_decode_hi_tok: runs the C reference and the tested
+ * version side by side on identical decoder states and cdfs, with cdf
+ * adaptation disabled and then enabled. After every call it compares the
+ * decoded token, the resulting MsacContext and the (possibly adapted) cdf.
+ */
+static void check_decode_hi_tok(MsacDSPContext *const c, uint8_t *const buf) {
+    ALIGN_STK_16(uint16_t, cdf, 2, [16]);
+    MsacContext s_ref, s_new;
+
+    if (check_func(c->hi_tok, "msac_decode_hi_tok")) {
+        declare_func(unsigned, MsacContext *s, uint16_t *cdf);
+        for (int update = 0; update < 2; update++) {
+            dav1d_msac_init(&s_ref, buf, BUF_SIZE, !update);
+            s_new = s_ref;
+            randomize_cdf(cdf[0], 3);
+            memcpy(cdf[1], cdf[0], sizeof(*cdf));
+            for (int iter = 0; iter < 64; iter++) {
+                const unsigned res_ref = call_ref(&s_ref, cdf[0]);
+                const unsigned res_new = call_new(&s_new, cdf[1]);
+                const int mismatch = res_ref != res_new ||
+                                     msac_cmp(&s_ref, &s_new) ||
+                                     memcmp(cdf[0], cdf[1], sizeof(*cdf));
+                if (mismatch) {
+                    if (fail())
+                        msac_dump(res_ref, res_new, &s_ref, &s_new,
+                                  cdf[0], cdf[1], 3);
+                    break;
+                }
+            }
+            if (update)
+                bench_new(&s_new, cdf[1]);
+        }
+    }
+    report("decode_hi_tok");
+}
+
void checkasm_check_msac(void) {
MsacDSPContext c;
c.symbol_adapt4 = dav1d_msac_decode_symbol_adapt_c;
@@ -207,6 +237,7 @@
c.bool_adapt = dav1d_msac_decode_bool_adapt_c;
c.bool_equi = dav1d_msac_decode_bool_equi_c;
c.bool = dav1d_msac_decode_bool_c;
+ c.hi_tok = dav1d_msac_decode_hi_tok_c;
#if ARCH_AARCH64 && HAVE_ASM
if (dav1d_get_cpu_flags() & DAV1D_ARM_CPU_FLAG_NEON) {
@@ -225,6 +256,7 @@
c.bool_adapt = dav1d_msac_decode_bool_adapt_sse2;
c.bool_equi = dav1d_msac_decode_bool_equi_sse2;
c.bool = dav1d_msac_decode_bool_sse2;
+ c.hi_tok = dav1d_msac_decode_hi_tok_sse2;
}
#endif
@@ -234,4 +266,5 @@
check_decode_symbol(&c, buf);
check_decode_bool(&c, buf);
+ check_decode_hi_tok(&c, buf);
}