ref: b20a2d63f21ed0c26d0527fed3eba9b6035842d7
parent: 30d5f4862889a8b336cd4b391f58482b5b40b196
author: Henrik Gramner <[email protected]>
date: Mon May 13 21:01:51 EDT 2019
x86-64: Add msac_decode_bool_equi asm
--- a/src/msac.c
+++ b/src/msac.c
@@ -27,7 +27,6 @@
#include "config.h"
-#include <assert.h>
#include <limits.h>
#include "common/intops.h"
@@ -68,7 +67,7 @@
ctx_refill(s);
}
-unsigned dav1d_msac_decode_bool_equi(MsacContext *const s) {
+unsigned dav1d_msac_decode_bool_equi_c(MsacContext *const s) { /* C version; msac.h maps dav1d_msac_decode_bool_equi to this or to an asm version */
ec_win vw, dif = s->dif;
unsigned ret, v, r = s->rng;
assert((dif >> (EC_WIN_SIZE - 16)) < r);
@@ -99,13 +98,6 @@
return !ret;
}
-unsigned dav1d_msac_decode_bools(MsacContext *const s, unsigned n) {
- unsigned v = 0;
- while (n--)
- v = (v << 1) | dav1d_msac_decode_bool_equi(s);
- return v;
-}
-
int dav1d_msac_decode_subexp(MsacContext *const s, const int ref,
const int n, const unsigned k)
{
@@ -120,15 +112,6 @@
const unsigned v = dav1d_msac_decode_bools(s, b) + a;
return ref * 2 <= n ? inv_recenter(ref, v) :
n - 1 - inv_recenter(n - 1 - ref, v);
-}
-
-int dav1d_msac_decode_uniform(MsacContext *const s, const unsigned n) {
- assert(n > 0);
- const int l = ulog2(n) + 1;
- assert(l > 1);
- const unsigned m = (1 << l) - n;
- const unsigned v = dav1d_msac_decode_bools(s, l - 1);
- return v < m ? v : (v << 1) - m + dav1d_msac_decode_bool_equi(s);
}
/* Decodes a symbol given an inverse cumulative distribution function (CDF)
--- a/src/msac.h
+++ b/src/msac.h
@@ -28,6 +28,7 @@
#ifndef DAV1D_SRC_MSAC_H
#define DAV1D_SRC_MSAC_H
+#include <assert.h> /* for the asserts in the inline helpers below; msac.c also relies on it now */
#include <stdint.h>
#include <stdlib.h>
@@ -47,12 +48,10 @@
int disable_cdf_update_flag);
unsigned dav1d_msac_decode_symbol_adapt_c(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
-unsigned dav1d_msac_decode_bool_equi(MsacContext *s);
+unsigned dav1d_msac_decode_bool_equi_c(MsacContext *s); /* C version; callers use the dav1d_msac_decode_bool_equi macro below */
unsigned dav1d_msac_decode_bool(MsacContext *s, unsigned f);
unsigned dav1d_msac_decode_bool_adapt(MsacContext *s, uint16_t *cdf);
-unsigned dav1d_msac_decode_bools(MsacContext *s, unsigned n);
int dav1d_msac_decode_subexp(MsacContext *s, int ref, int n, unsigned k);
-int dav1d_msac_decode_uniform(MsacContext *s, unsigned n);
/* Supported n_symbols ranges: adapt4: 1-5, adapt8: 1-8, adapt16: 4-16 */
#if ARCH_AARCH64 && HAVE_ASM
@@ -65,6 +64,7 @@
#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt4_neon
#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt8_neon
#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_neon
+#define dav1d_msac_decode_bool_equi dav1d_msac_decode_bool_equi_c /* AArch64 asm builds still use the C version */
#elif ARCH_X86_64 && HAVE_ASM
unsigned dav1d_msac_decode_symbol_adapt4_sse2(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
@@ -72,13 +72,32 @@
size_t n_symbols);
unsigned dav1d_msac_decode_symbol_adapt16_sse2(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
+unsigned dav1d_msac_decode_bool_equi_sse2(MsacContext *s); /* asm, defined in src/x86/msac.asm */
#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt4_sse2
#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt8_sse2
#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_sse2
+#define dav1d_msac_decode_bool_equi dav1d_msac_decode_bool_equi_sse2 /* x86-64 asm builds */
#else
#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt_c
#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt_c
#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt_c
+#define dav1d_msac_decode_bool_equi dav1d_msac_decode_bool_equi_c /* no asm: C fallback */
#endif
+
+static inline unsigned dav1d_msac_decode_bools(MsacContext *const s, unsigned n) {
+    unsigned bits = 0; /* n bits via dav1d_msac_decode_bool_equi; first bit ends up most significant */
+    for (; n != 0; n--)
+        bits = (bits << 1) | dav1d_msac_decode_bool_equi(s);
+    return bits;
+}
+
+static inline int dav1d_msac_decode_uniform(MsacContext *const s, const unsigned n) {
+    assert(n > 0); /* reads l - 1 bits, plus one extra bit when the prefix is >= m */
+    const int l = ulog2(n) + 1;
+    assert(l > 1 && l < 32); /* ulog2(n) >= 1; l < 32 keeps 1U << l defined and the result within int */
+    const unsigned m = (1U << l) - n; /* unsigned shift: (1 << 31) on int would be UB */
+    const unsigned v = dav1d_msac_decode_bools(s, l - 1);
+    return v < m ? v : (v << 1) - m + dav1d_msac_decode_bool_equi(s);
+}
#endif /* DAV1D_SRC_MSAC_H */
--- a/src/x86/msac.asm
+++ b/src/x86/msac.asm
@@ -111,6 +111,7 @@
sub r2d, r1d ; rng
shl r1, 48
add r4, r1 ; ~dif
+.renorm3: ; also jumped to from msac_decode_bool_equi with r2d = rng, r4 = ~dif
mov r1d, [sq+msac.cnt]
movifnidn t0, sq
bsr ecx, r2d
@@ -283,5 +284,22 @@
add rsp, 48
%endif
jmp m(msac_decode_symbol_adapt4).renorm2
+
+cglobal msac_decode_bool_equi, 1, 7, 0, s ; eax = (dif < vw), then renormalize via adapt4's .renorm3
+ mov r1d, [sq+msac.rng] ; r
+ mov r4, [sq+msac.dif] ; dif
+ mov r2d, r1d ; copy of r
+ mov r1b, 8 ; r1d = (r & ~0xff) | 8 = 2 * v
+ mov r3, r4 ; saved dif, restored below if dif < vw
+ mov eax, r1d ; rax = 2 * v (32-bit mov zero-extends)
+ shr r1d, 1 ; v = ((r >> 8) << 7) + 4
+ shl rax, 47 ; vw = (2 * v) << 47 = v << 48
+ sub r2d, r1d ; r - v
+ sub r4, rax ; dif - vw; CF set if dif < vw
+ cmovb r2d, r1d ; dif < vw: rng = v (otherwise r - v)
+ cmovb r4, r3 ; dif < vw: keep original dif (otherwise dif - vw)
+ setb al ; eax = (dif < vw); bits 8-31 are zero after the shl, and the upper 32 bits of rax contain garbage, which is OK for a 32-bit return
+ not r4 ; .renorm3 expects ~dif
+ jmp m(msac_decode_symbol_adapt4).renorm3
%endif
--- a/tests/checkasm/msac.c
+++ b/tests/checkasm/msac.c
@@ -32,14 +32,18 @@
#include <string.h>
+#define BUF_SIZE 8192 /* bytes of random input handed to every msac check */
+
/* The normal code doesn't use function pointers */
typedef unsigned (*decode_symbol_adapt_fn)(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
+typedef unsigned (*decode_bool_equi_fn)(MsacContext *s); /* matches dav1d_msac_decode_bool_equi_{c,sse2} */
typedef struct {
decode_symbol_adapt_fn symbol_adapt4;
decode_symbol_adapt_fn symbol_adapt8;
decode_symbol_adapt_fn symbol_adapt16;
+ decode_bool_equi_fn bool_equi; /* C or asm version under test */
} MsacDSPContext;
static void randomize_cdf(uint16_t *const cdf, int n) {
@@ -61,7 +65,7 @@
if (check_func(c->symbol_adapt##n, "msac_decode_symbol_adapt%d", n)) { \
for (int cdf_update = 0; cdf_update <= 1; cdf_update++) { \
for (int ns = n_min; ns <= n_max; ns++) { \
- dav1d_msac_init(&s_c, buf, sizeof(buf), !cdf_update); \
+ dav1d_msac_init(&s_c, buf, BUF_SIZE, !cdf_update); /* buf is a pointer now, so sizeof no longer works */ \
s_a = s_c; \
randomize_cdf(cdf[0], ns); \
memcpy(cdf[1], cdf[0], sizeof(*cdf)); \
@@ -81,14 +85,13 @@
} \
} while (0)
-static void check_decode_symbol_adapt(MsacDSPContext *const c) {
+static void check_decode_symbol_adapt(MsacDSPContext *const c,
+ uint8_t *const buf) /* BUF_SIZE random bytes, also used by the bool_equi check */
+{
/* Use an aligned CDF buffer for more consistent benchmark
* results, and a misaligned one for checking correctness. */
ALIGN_STK_16(uint16_t, cdf, 2, [17]);
MsacContext s_c, s_a;
- uint8_t buf[1024];
- for (int i = 0; i < 1024; i++)
- buf[i] = rnd();
declare_func(unsigned, MsacContext *s, uint16_t *cdf, size_t n_symbols);
CHECK_SYMBOL_ADAPT( 4, 1, 5);
@@ -97,11 +100,33 @@
report("decode_symbol_adapt");
}
+static void check_decode_bool_equi(MsacDSPContext *const c,
+                                   uint8_t *const buf)
+{
+    declare_func(unsigned, MsacContext *s);
+    /* Run the reference and the tested version in lockstep on one bitstream. */
+    if (check_func(c->bool_equi, "msac_decode_bool_equi")) {
+        MsacContext ref_ctx, new_ctx;
+        dav1d_msac_init(&ref_ctx, buf, BUF_SIZE, 1);
+        new_ctx = ref_ctx;
+        for (int left = 64; left > 0; left--) {
+            const unsigned expected = call_ref(&ref_ctx);
+            const unsigned got = call_new(&new_ctx); /* decoded bit and context state must both match */
+            if (got != expected || msac_cmp(&ref_ctx, &new_ctx))
+                fail();
+        }
+        bench_new(&new_ctx);
+    }
+
+    report("decode_bool_equi");
+}
+
void checkasm_check_msac(void) {
MsacDSPContext c;
c.symbol_adapt4 = dav1d_msac_decode_symbol_adapt_c;
c.symbol_adapt8 = dav1d_msac_decode_symbol_adapt_c;
c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt_c;
+ c.bool_equi = dav1d_msac_decode_bool_equi_c; /* C default; the x86-64 branch below swaps in the SSE2 asm */
#if ARCH_AARCH64 && HAVE_ASM
if (dav1d_get_cpu_flags() & DAV1D_ARM_CPU_FLAG_NEON) {
@@ -114,8 +139,14 @@
c.symbol_adapt4 = dav1d_msac_decode_symbol_adapt4_sse2;
c.symbol_adapt8 = dav1d_msac_decode_symbol_adapt8_sse2;
c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_sse2;
+ c.bool_equi = dav1d_msac_decode_bool_equi_sse2;
}
#endif
- check_decode_symbol_adapt(&c);
+ uint8_t buf[BUF_SIZE]; /* one random input shared by both checks below */
+ for (int i = 0; i < BUF_SIZE; i++)
+ buf[i] = rnd();
+
+ check_decode_symbol_adapt(&c, buf);
+ check_decode_bool_equi(&c, buf);
}