ref: e16e2726e8c1019bfdd62b73caa8cba255b7ee1c
parent: e25ed5550ef7304e8324fa2de981522e7cb14ec4
author: Henrik Gramner <[email protected]>
date: Tue May 14 16:13:06 EDT 2019
x86-64: Add msac_decode_bool and msac_decode_bool_adapt asm
--- a/src/msac.c
+++ b/src/msac.c
@@ -85,7 +85,7 @@
/* Decode a single binary value.
* f: The probability that the bit is one
* Return: The value decoded (0 or 1). */
-unsigned dav1d_msac_decode_bool(MsacContext *const s, const unsigned f) {
+unsigned dav1d_msac_decode_bool_c(MsacContext *const s, const unsigned f) {
ec_win vw, dif = s->dif;
unsigned ret, v, r = s->rng;
assert((dif >> (EC_WIN_SIZE - 16)) < r);
@@ -155,8 +155,8 @@
return val;
}
-unsigned dav1d_msac_decode_bool_adapt(MsacContext *const s,
- uint16_t *const cdf)
+unsigned dav1d_msac_decode_bool_adapt_c(MsacContext *const s,
+ uint16_t *const cdf)
{
const unsigned bit = dav1d_msac_decode_bool(s, *cdf);
@@ -164,11 +164,10 @@
// update_cdf() specialized for boolean CDFs
const unsigned count = cdf[1];
const int rate = (count >> 4) | 4;
- if (bit) {
+ if (bit)
cdf[0] += (32768 - cdf[0]) >> rate;
- } else {
+ else
cdf[0] -= cdf[0] >> rate;
- }
cdf[1] = count + (count < 32);
}
--- a/src/msac.h
+++ b/src/msac.h
@@ -48,9 +48,9 @@
int disable_cdf_update_flag);
unsigned dav1d_msac_decode_symbol_adapt_c(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
+unsigned dav1d_msac_decode_bool_adapt_c(MsacContext *s, uint16_t *cdf);
unsigned dav1d_msac_decode_bool_equi_c(MsacContext *s);
-unsigned dav1d_msac_decode_bool(MsacContext *s, unsigned f);
-unsigned dav1d_msac_decode_bool_adapt(MsacContext *s, uint16_t *cdf);
+unsigned dav1d_msac_decode_bool_c(MsacContext *s, unsigned f);
int dav1d_msac_decode_subexp(MsacContext *s, int ref, int n, unsigned k);
/* Supported n_symbols ranges: adapt4: 1-5, adapt8: 1-8, adapt16: 4-16 */
@@ -64,7 +64,9 @@
#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt4_neon
#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt8_neon
#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_neon
+#define dav1d_msac_decode_bool_adapt dav1d_msac_decode_bool_adapt_c
#define dav1d_msac_decode_bool_equi dav1d_msac_decode_bool_equi_c
+#define dav1d_msac_decode_bool dav1d_msac_decode_bool_c
#elif ARCH_X86_64 && HAVE_ASM
unsigned dav1d_msac_decode_symbol_adapt4_sse2(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
@@ -72,16 +74,22 @@
size_t n_symbols);
unsigned dav1d_msac_decode_symbol_adapt16_sse2(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
+unsigned dav1d_msac_decode_bool_adapt_sse2(MsacContext *s, uint16_t *cdf);
unsigned dav1d_msac_decode_bool_equi_sse2(MsacContext *s);
+unsigned dav1d_msac_decode_bool_sse2(MsacContext *s, unsigned f);
#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt4_sse2
#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt8_sse2
#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_sse2
+#define dav1d_msac_decode_bool_adapt dav1d_msac_decode_bool_adapt_sse2
#define dav1d_msac_decode_bool_equi dav1d_msac_decode_bool_equi_sse2
+#define dav1d_msac_decode_bool dav1d_msac_decode_bool_sse2
#else
#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt_c
#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt_c
#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt_c
+#define dav1d_msac_decode_bool_adapt dav1d_msac_decode_bool_adapt_c
#define dav1d_msac_decode_bool_equi dav1d_msac_decode_bool_equi_c
+#define dav1d_msac_decode_bool dav1d_msac_decode_bool_c
#endif
static inline unsigned dav1d_msac_decode_bools(MsacContext *const s, unsigned n) {
--- a/src/x86/msac.asm
+++ b/src/x86/msac.asm
@@ -114,6 +114,7 @@
.renorm3:
mov r1d, [sq+msac.cnt]
movifnidn t0, sq
+.renorm4:
bsr ecx, r2d
xor ecx, 15 ; d
shl r2d, cl
@@ -285,6 +286,58 @@
%endif
jmp m(msac_decode_symbol_adapt4).renorm2
+cglobal msac_decode_bool_adapt, 2, 7, 0, s, cdf ; decode one bool with f = cdf[0]; adapt cdf[0..1] when msac.update_cdf != 0
+ movzx eax, word [cdfq] ; eax = f = cdf[0]
+ movzx r3d, byte [sq+msac.rng+1] ; r3d = rng >> 8
+ mov r4, [sq+msac.dif] ; r4 = dif
+ mov r2d, [sq+msac.rng] ; r2d = rng
+ mov r5d, eax ; r5d = cdf[0], kept for the CDF update below
+ and eax, ~63 ; f & ~63 = (f >> 6) << 6
+ imul eax, r3d ; (rng >> 8) * (f & ~63)
+%if UNIX64
+ mov r7, r4 ; keep a register copy of the original dif
+%endif
+ shr eax, 7 ; = ((rng >> 8) * (f >> 6)) >> 1
+ add eax, 4 ; v = (((rng >> 8) * (f >> 6)) >> 1) + 4
+ mov r3d, eax ; r3d = v
+ shl rax, 48 ; vw = v << 48; also zeroes eax so setb below leaves a clean 0/1 in eax
+ sub r2d, r3d ; r - v: new rng when the decoded bit is 0
+ sub r4, rax ; dif - vw; borrow (CF) iff dif < vw
+ cmovb r2d, r3d ; dif < vw: rng = v
+ mov r3d, [sq+msac.update_cdf] ; r3d = update_cdf flag (v is no longer needed)
+%if UNIX64
+ cmovb r4, r7 ; dif < vw: keep the original dif
+%else
+ cmovb r4, [sq+msac.dif] ; dif < vw: keep the original dif (reloaded from memory)
+%endif
+ setb al ; eax = bit = (dif < vw)
+ not r4 ; invert dif before renormalization, as msac_decode_bool_equi does
+ test r3d, r3d ; CDF adaptation enabled?
+ jz m(msac_decode_symbol_adapt4).renorm3 ; no: skip the cdf[] update
+%if WIN64
+ push r7 ; save r7 (restored before the final jump)
+%endif
+ movzx r7d, word [cdfq+2] ; r7d = count = cdf[1]
+ movifnidn t0, sq ; copy s to t0: ecx is overwritten below and holds sq on WIN64
+ lea ecx, [r7+64] ; count + 64
+ cmp r7d, 32 ; CF = (count < 32)
+ adc r7d, 0 ; count += (count < 32)
+ mov [cdfq+2], r7w ; cdf[1] = updated count
+ imul r7d, eax, -32769 ; r7d = bit ? -32769 : 0
+ shr ecx, 4 ; rate = (count + 64) >> 4 = (count >> 4) + 4, from the count before the increment
+ add r7d, r5d ; if (bit)
+ sub r5d, eax ; cdf[0] -= ((cdf[0] - 32769) >> rate) + 1;
+ sar r7d, cl ; else
+ sub r5d, r7d ; cdf[0] -= cdf[0] >> rate;
+ mov [cdfq], r5w ; store the updated cdf[0]
+%if WIN64
+ mov r1d, [t0+msac.cnt] ; load cnt via t0 (sq was clobbered), replacing renorm3's load
+ pop r7 ; restore r7
+ jmp m(msac_decode_symbol_adapt4).renorm4 ; renorm4 = renorm3 without its cnt/t0 loads
+%else
+ jmp m(msac_decode_symbol_adapt4).renorm3 ; sq is still valid on this path
+%endif
+
cglobal msac_decode_bool_equi, 1, 7, 0, s
mov r1d, [sq+msac.rng]
mov r4, [sq+msac.dif]
@@ -299,6 +352,25 @@
cmovb r2d, r1d
cmovb r4, r3
setb al ; the upper 32 bits contains garbage but that's OK
+ not r4
+ jmp m(msac_decode_symbol_adapt4).renorm3
+
+cglobal msac_decode_bool, 2, 7, 0, s, f ; decode one bool with caller-supplied probability f (no CDF adaptation)
+ movzx eax, byte [sq+msac.rng+1] ; r >> 8
+ mov r4, [sq+msac.dif] ; r4 = dif
+ mov r2d, [sq+msac.rng] ; r2d = r
+ and r1d, ~63 ; f & ~63 = (f >> 6) << 6
+ imul eax, r1d ; (r >> 8) * (f & ~63)
+ mov r3, r4 ; r3 = copy of the original dif
+ shr eax, 7 ; = ((r >> 8) * (f >> 6)) >> 1
+ add eax, 4 ; v = (((r >> 8) * (f >> 6)) >> 1) + 4
+ mov r1d, eax ; r1d = v
+ shl rax, 48 ; vw = v << 48
+ sub r2d, r1d ; r - v: new rng when the decoded bit is 0
+ sub r4, rax ; dif - vw; borrow (CF) iff dif < vw
+ cmovb r2d, r1d ; dif < vw: r = v
+ cmovb r4, r3 ; dif < vw: keep the original dif
+ setb al ; decoded bit = (dif < vw)
not r4
jmp m(msac_decode_symbol_adapt4).renorm3
--- a/tests/checkasm/msac.c
+++ b/tests/checkasm/msac.c
@@ -37,13 +37,17 @@
/* The normal code doesn't use function pointers */
typedef unsigned (*decode_symbol_adapt_fn)(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
+typedef unsigned (*decode_bool_adapt_fn)(MsacContext *s, uint16_t *cdf); /* f = cdf[0], adaptation count in cdf[1] */
typedef unsigned (*decode_bool_equi_fn)(MsacContext *s);
+typedef unsigned (*decode_bool_fn)(MsacContext *s, unsigned f); /* fixed probability f, no adaptation */
typedef struct {
decode_symbol_adapt_fn symbol_adapt4;
decode_symbol_adapt_fn symbol_adapt8;
decode_symbol_adapt_fn symbol_adapt16;
+ decode_bool_adapt_fn bool_adapt;
decode_bool_equi_fn bool_equi;
+ decode_bool_fn decode_bool; /* not "bool": that is a C23 keyword and a <stdbool.h> macro */
} MsacDSPContext;
static void randomize_cdf(uint16_t *const cdf, int n) {
@@ -85,9 +89,7 @@
} \
} while (0)
-static void check_decode_symbol_adapt(MsacDSPContext *const c,
- uint8_t *const buf)
-{
+static void check_decode_symbol(MsacDSPContext *const c, uint8_t *const buf) {
/* Use an aligned CDF buffer for more consistent benchmark
* results, and a misaligned one for checking correctness. */
ALIGN_STK_16(uint16_t, cdf, 2, [17]);
@@ -97,16 +99,36 @@
CHECK_SYMBOL_ADAPT( 4, 1, 5);
CHECK_SYMBOL_ADAPT( 8, 1, 8);
CHECK_SYMBOL_ADAPT(16, 4, 16);
- report("decode_symbol_adapt");
+ report("decode_symbol");
}
-static void check_decode_bool_equi(MsacDSPContext *const c,
- uint8_t *const buf)
-{
- declare_func(unsigned, MsacContext *s);
+static void check_decode_bool(MsacDSPContext *const c, uint8_t *const buf) {
+ MsacContext s_c, s_a;
+
+ if (check_func(c->bool_adapt, "msac_decode_bool_adapt")) {
+ declare_func(unsigned, MsacContext *s, uint16_t *cdf);
+ uint16_t cdf[2][2]; /* [0]: reference copy, [1]: copy under test */
+ for (int cdf_update = 0; cdf_update <= 1; cdf_update++) {
+ dav1d_msac_init(&s_c, buf, BUF_SIZE, !cdf_update); /* last arg disables CDF updates */
+ s_a = s_c;
+ cdf[0][0] = cdf[1][0] = rnd() % 32767 + 1; /* probability in 1..32767 */
+ cdf[0][1] = cdf[1][1] = 0; /* adaptation count starts at 0 */
+ for (int i = 0; i < 64; i++) {
+ unsigned c_res = call_ref(&s_c, cdf[0]);
+ unsigned a_res = call_new(&s_a, cdf[1]);
+ if (c_res != a_res || msac_cmp(&s_c, &s_a) ||
+ memcmp(cdf[0], cdf[1], sizeof(*cdf))) /* compares both cdf entries */
+ {
+ fail();
+ }
+ }
+ if (cdf_update)
+ bench_new(&s_a, cdf[0]);
+ }
+ }
if (check_func(c->bool_equi, "msac_decode_bool_equi")) {
- MsacContext s_c, s_a;
+ declare_func(unsigned, MsacContext *s);
dav1d_msac_init(&s_c, buf, BUF_SIZE, 1);
s_a = s_c;
for (int i = 0; i < 64; i++) {
@@ -118,7 +140,21 @@
bench_new(&s_a);
}
- report("decode_bool_equi");
+ if (check_func(c->decode_bool, "msac_decode_bool")) {
+ declare_func(unsigned, MsacContext *s, unsigned f);
+ dav1d_msac_init(&s_c, buf, BUF_SIZE, 1);
+ s_a = s_c;
+ for (int i = 0; i < 64; i++) {
+ const unsigned f = rnd() & 0x7fff; /* random probability in 0..32767 */
+ unsigned c_res = call_ref(&s_c, f);
+ unsigned a_res = call_new(&s_a, f);
+ if (c_res != a_res || msac_cmp(&s_c, &s_a))
+ fail();
+ }
+ bench_new(&s_a, 16384); /* midpoint of the tested 0..32767 range */
+ }
+
+ report("decode_bool");
}
void checkasm_check_msac(void) {
@@ -126,7 +162,9 @@
c.symbol_adapt4 = dav1d_msac_decode_symbol_adapt_c;
c.symbol_adapt8 = dav1d_msac_decode_symbol_adapt_c;
c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt_c;
+ c.bool_adapt = dav1d_msac_decode_bool_adapt_c;
c.bool_equi = dav1d_msac_decode_bool_equi_c;
+ c.decode_bool = dav1d_msac_decode_bool_c;
#if ARCH_AARCH64 && HAVE_ASM
if (dav1d_get_cpu_flags() & DAV1D_ARM_CPU_FLAG_NEON) {
@@ -139,7 +177,9 @@
c.symbol_adapt4 = dav1d_msac_decode_symbol_adapt4_sse2;
c.symbol_adapt8 = dav1d_msac_decode_symbol_adapt8_sse2;
c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_sse2;
+ c.bool_adapt = dav1d_msac_decode_bool_adapt_sse2;
c.bool_equi = dav1d_msac_decode_bool_equi_sse2;
+ c.decode_bool = dav1d_msac_decode_bool_sse2;
}
#endif
@@ -147,6 +187,6 @@
for (int i = 0; i < BUF_SIZE; i++)
buf[i] = rnd();
- check_decode_symbol_adapt(&c, buf);
- check_decode_bool_equi(&c, buf);
+ check_decode_symbol(&c, buf);
+ check_decode_bool(&c, buf);
}