ref: 3c8110a947e28d7a82087aadc520cd25ac720fe4
parent: 35ab85bbbc6f8270b5ba973cf8452cb5aa2b32eb
author: Lynne <[email protected]>
date: Mon Jan 13 06:06:05 EST 2020
x86/msac: add an avx2 version for msac_decode_symbol_adapt16

msac_decode_symbol_adapt16_c:    55.1
msac_decode_symbol_adapt16_sse2: 30.3
msac_decode_symbol_adapt16_avx2: 28.0

Most code written by Henrik Gramner.
--- a/src/x86/msac.asm
+++ b/src/x86/msac.asm
@@ -67,7 +67,7 @@
.update_cdf: resd 1
endstruc
-%define m(x) mangle(private_prefix %+ _ %+ x %+ SUFFIX)
+%define m(x, y) mangle(private_prefix %+ _ %+ x %+ y) ; y = symbol suffix: SUFFIX for the current ISA, or an explicit one such as _sse2 to reach another ISA version
SECTION .text
@@ -240,7 +240,7 @@
pcmpeqw m1, m2
pmovmskb eax, m1
test t3d, t3d
- jz m(msac_decode_symbol_adapt4).renorm
+ jz m(msac_decode_symbol_adapt4, SUFFIX).renorm ; same-ISA adapt4 renorm; SUFFIX is now passed explicitly to the two-argument m
movzx t3d, word [t1+t4*2]
pcmpeqw m2, m2
mov t2d, t3d
@@ -257,7 +257,7 @@
paddw m0, m2
mova [t1], m0
mov [t1+t4*2], t2w
- jmp m(msac_decode_symbol_adapt4).renorm
+ jmp m(msac_decode_symbol_adapt4, SUFFIX).renorm ; same-ISA adapt4 renorm; SUFFIX is now passed explicitly to the two-argument m
cglobal msac_decode_symbol_adapt16, 0, 6, 6
DECODE_SYMBOL_ADAPT_INIT
@@ -330,7 +330,7 @@
%if WIN64
add rsp, 48
%endif
- jmp m(msac_decode_symbol_adapt4).renorm2
+ jmp m(msac_decode_symbol_adapt4, SUFFIX).renorm2 ; stack released above on WIN64; continue in the same-ISA adapt4 renorm2
cglobal msac_decode_bool_adapt, 0, 6, 0
movifnidn t1, r1mp
@@ -366,7 +366,7 @@
%endif
not t4
test t3d, t3d
- jz m(msac_decode_symbol_adapt4).renorm3
+ jz m(msac_decode_symbol_adapt4, SUFFIX).renorm3 ; same-ISA adapt4 renorm3; SUFFIX is now passed explicitly
%if UNIX64 == 0
push t6
%endif
@@ -390,13 +390,13 @@
%if WIN64
mov t1d, [t7+msac.cnt]
pop t6
- jmp m(msac_decode_symbol_adapt4).renorm4
+ jmp m(msac_decode_symbol_adapt4, SUFFIX).renorm4 ; WIN64: t6 restored and t1d = msac.cnt loaded just above, so enter at renorm4
%else
%if ARCH_X86_64 == 0
pop t5
pop t6
%endif
- jmp m(msac_decode_symbol_adapt4).renorm3
+ jmp m(msac_decode_symbol_adapt4, SUFFIX).renorm3 ; other ABIs: enter at renorm3 of the same-ISA adapt4
%endif
cglobal msac_decode_bool_equi, 0, 6, 0
@@ -418,7 +418,7 @@
%if ARCH_X86_64 == 0
movzx eax, al
%endif
- jmp m(msac_decode_symbol_adapt4).renorm3
+ jmp m(msac_decode_symbol_adapt4, SUFFIX).renorm3 ; same-ISA adapt4 renorm3; SUFFIX is now passed explicitly
cglobal msac_decode_bool, 0, 6, 0
movifnidn t0, r0mp
@@ -442,7 +442,7 @@
%if ARCH_X86_64 == 0
movzx eax, al
%endif
- jmp m(msac_decode_symbol_adapt4).renorm3
+ jmp m(msac_decode_symbol_adapt4, SUFFIX).renorm3 ; same-ISA adapt4 renorm3; SUFFIX is now passed explicitly
%macro HI_TOK 1 ; update_cdf
%if ARCH_X86_64 == 0
@@ -598,3 +598,71 @@
HI_TOK 1
.no_update_cdf:
HI_TOK 0
+; unsigned msac_decode_symbol_adapt16(MsacContext *s, uint16_t *cdf, size_t n_symbols), AVX2 version; eax = symbol index on entry to the shared SSE2 renorm2 tail
+%if ARCH_X86_64 ; built for x86-64 only (msac.h and checkasm guard the AVX2 symbol the same way)
+INIT_YMM avx2 ; m registers are 256-bit ymm in this function
+cglobal msac_decode_symbol_adapt16, 3, 6, 6 ; 3 args (used below as t0 = s, t1 = cdf, t2 = n_symbols), 6 GPRs, 6 vector regs
+ lea rax, [pw_0xff00] ; rax = address of pw_0xff00, also the base of the n_symbols-indexed table load below
+ vpbroadcastw m2, [t0+msac.rng] ; m2 = low word of rng in all 16 lanes
+ mova m0, [t1] ; m0 = cdf[0..15]; aligned load, so cdf must be 32-byte aligned
+ vpbroadcastw m3, [t0+msac.dif+6] ; m3 = top 16 bits of the 64-bit dif in all lanes
+ vbroadcasti128 m4, [rax] ; m4 = 16 bytes at pw_0xff00 in both 128-bit halves, used as the rng mask
+ mov t3d, [t0+msac.update_cdf] ; t3d = update_cdf; zero skips the adaptation
+ mov t4d, t2d ; t4 = n_symbols, the word index of the counter stored after the cdf
+ not t2 ; t2 = ~n_symbols = -(n_symbols + 1)
+%if STACK_ALIGNMENT < 32 ; rsp may be only 16-byte aligned, so align buf by hand
+ mov r5, rsp ; r5 = incoming rsp
+%if WIN64 ; no red zone on Win64: reserve stack for buf
+ and rsp, ~31 ; round rsp down to a multiple of 32
+ sub rsp, 40 ; rsp = 24 (mod 32); restored from r5 before the tail jump
+%else
+ and r5, ~31 ; r5 = rsp rounded down to a multiple of 32
+ %define buf r5-32 ; 32-byte aligned; buf-4 .. buf+31 stay within 68 bytes below rsp, inside the red zone
+%endif
+%elif WIN64 ; STACK_ALIGNMENT >= 32: rsp assumed 32-byte aligned at the call
+ sub rsp, 64 ; undone with add rsp, 64 before the tail jump
+%else
+ %define buf rsp-56 ; entry rsp = 24 (mod 32) when callers keep 32-byte alignment, so rsp-56 is 32-byte aligned (red zone)
+%endif ; NOTE(review): buf is not defined for WIN64 in this function and relies on an earlier definition; confirm it resolves to a 32-byte aligned slot inside the space reserved above
+ psrlw m1, m0, 6 ; m1 = cdf >> 6
+ movd [buf-4], xm2 ; two copies of rng; [buf-2] = rng serves as u for symbol 0
+ pand m2, m4 ; m2 = rng masked with pw_0xff00
+ psllw m1, 7 ; m1 = (cdf >> 6) << 7
+ pmulhuw m1, m2 ; m1 = (m1 * m2) >> 16, unsigned
+ paddw m1, [rax+t2*2] ; add the table words starting n_symbols+1 words before pw_0xff00 (table not shown here)
+ mova [buf], m1 ; spill the 16 v values for the scalar lookup in .renorm
+ pmaxuw m1, m3
+ pcmpeqw m1, m3 ; lane = 0xffff where dif_hi >= v (unsigned), else 0
+ pmovmskb eax, m1 ; 2 mask bits per 16-bit lane
+ test t3d, t3d
+ jz .renorm ; update_cdf == 0: leave cdf untouched
+ movzx t3d, word [t1+t4*2] ; t3d = count = cdf[n_symbols]
+ pcmpeqw m2, m2 ; m2 = all ones
+ lea t2d, [t3+80]
+ shr t2d, 4 ; rate = (count + 80) >> 4 = 5 + (count >> 4)
+ cmp t3d, 32 ; CF = (count < 32)
+ adc t3d, 0 ; count += CF, i.e. increment saturating at 32
+ movd xm3, t2d ; xm3 = rate, the shift count for psraw
+ pavgw m2, m1 ; m2 = 0xffff in matched lanes, 0x8000 elsewhere
+ psubw m2, m0 ; m2 -= cdf
+ psubw m0, m1 ; cdf += 1 in matched lanes (mask word is -1)
+ psraw m2, xm3 ; arithmetic shift right by rate
+ paddw m0, m2 ; m0 = adapted cdf
+ mova [t1], m0 ; stores all 16 words, overwriting cdf[n_symbols] and any words after it
+ mov [t1+t4*2], t3w ; then write the updated count back to cdf[n_symbols]
+.renorm: ; eax = lane mask from pmovmskb
+ tzcnt eax, eax ; eax = 2 * first matching lane; NOTE(review): assumes some lane always matches, since tzcnt of 0 gives 32
+ mov t4, [t0+msac.dif] ; t4 = full 64-bit dif
+ movzx t1d, word [buf+rax-0] ; t1d = v of the decoded symbol
+ movzx t2d, word [buf+rax-2] ; t2d = u, the previous v, or rng for symbol 0
+ shr eax, 1 ; eax = decoded symbol index
+%if WIN64 ; release the stack reserved above
+%if STACK_ALIGNMENT < 32
+ mov rsp, r5 ; restore the incoming rsp saved in r5
+%else
+ add rsp, 64
+%endif
+%endif
+ vzeroupper ; clear upper ymm state before continuing in SSE2 code
+ jmp m(msac_decode_symbol_adapt4, _sse2).renorm2 ; explicit _sse2 since this patch adds no AVX2 adapt4; NOTE(review): confirm renorm2 consumes exactly eax, t1d, t2d and t4 as set here
+%endif
--- a/src/x86/msac.h
+++ b/src/x86/msac.h
@@ -39,6 +39,10 @@
unsigned dav1d_msac_decode_bool_sse2(MsacContext *s, unsigned f);
unsigned dav1d_msac_decode_hi_tok_sse2(MsacContext *s, uint16_t *cdf);
+/* AVX2 entry point, referenced by dav1d_msac_init_x86() and by checkasm */
+unsigned dav1d_msac_decode_symbol_adapt16_avx2(MsacContext *s, uint16_t *cdf,
+ size_t n_symbols);
+
#if ARCH_X86_64 || defined(__SSE2__) || (defined(_M_IX86_FP) && _M_IX86_FP >= 2)
#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt4_sse2
#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt8_sse2
@@ -49,7 +53,9 @@
#define dav1d_msac_decode_bool_equi dav1d_msac_decode_bool_equi_sse2
#define dav1d_msac_decode_bool dav1d_msac_decode_bool_sse2
-#if defined(__SSE2__) || (defined(_M_IX86_FP) && _M_IX86_FP >= 2)
+#if ARCH_X86_64 /* x86-64: runtime dispatch through the pointer set in dav1d_msac_init_x86() */
+#define dav1d_msac_decode_symbol_adapt16(ctx, cdf, symb) ((ctx)->symbol_adapt16(ctx, cdf, symb)) /* NOTE(review): ctx is evaluated twice */
+#elif defined(__SSE2__) || (defined(_M_IX86_FP) && _M_IX86_FP >= 2) /* 32-bit with SSE2 guaranteed at compile time: call the SSE2 version directly */
#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_sse2
#endif
--- a/src/x86/msac_init.c
+++ b/src/x86/msac_init.c
@@ -28,14 +28,15 @@
#include "src/msac.h"
#include "src/x86/msac.h"
-unsigned dav1d_msac_decode_symbol_adapt16_avx2(MsacContext *s, uint16_t *cdf,
- size_t n_symbols);
-
void dav1d_msac_init_x86(MsacContext *const s) {
const unsigned flags = dav1d_get_cpu_flags();
if (flags & DAV1D_X86_CPU_FLAG_SSE2) {
s->symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_sse2;
+ }
+ /* AVX2 is checked last so it overrides SSE2; with neither flag, symbol_adapt16 is left untouched here. */
+ if (flags & DAV1D_X86_CPU_FLAG_AVX2) {
+ s->symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_avx2;
}
}
--- a/tests/checkasm/msac.c
+++ b/tests/checkasm/msac.c
@@ -258,6 +258,12 @@
c.bool = dav1d_msac_decode_bool_sse2;
c.hi_tok = dav1d_msac_decode_hi_tok_sse2;
}
+
+#if ARCH_X86_64 /* the AVX2 version is only assembled for x86-64 */
+ if (dav1d_get_cpu_flags() & DAV1D_X86_CPU_FLAG_AVX2) { /* later assignment wins: test AVX2 when the CPU has it */
+ c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_avx2;
+ }
+#endif
#endif
uint8_t buf[BUF_SIZE];