ref: abf5c8ed64b7e59a434ed0f35babe61579242d20
parent: 1dab60cc91339f977b39db4849417483d5b3c15b
parent: d910274f790ba586c40a625bb9120d13554c3443
author: Jean-Marc Valin <[email protected]>
date: Fri Sep 19 04:02:50 EDT 2008
Merge branch 'cwrs_speedup' (derf's cwrs changes) Conflicts: libcelt/cwrs.c
--- a/libcelt/cwrs.c
+++ b/libcelt/cwrs.c
@@ -92,6 +92,75 @@
}
}
+#define MASK32 (0xFFFFFFFF)
+
+/*INV_TABLE[i] holds the multiplicative inverse of (2*i+1) mod 2**32.*/
+static const unsigned INV_TABLE[128]={
+ 0x00000001,0xAAAAAAAB,0xCCCCCCCD,0xB6DB6DB7,
+ 0x38E38E39,0xBA2E8BA3,0xC4EC4EC5,0xEEEEEEEF,
+ 0xF0F0F0F1,0x286BCA1B,0x3CF3CF3D,0xE9BD37A7,
+ 0xC28F5C29,0x684BDA13,0x4F72C235,0xBDEF7BDF,
+ 0x3E0F83E1,0x8AF8AF8B,0x914C1BAD,0x96F96F97,
+ 0xC18F9C19,0x2FA0BE83,0xA4FA4FA5,0x677D46CF,
+ 0x1A1F58D1,0xFAFAFAFB,0x8C13521D,0x586FB587,
+ 0xB823EE09,0xA08AD8F3,0xC10C9715,0xBEFBEFBF,
+ 0xC0FC0FC1,0x07A44C6B,0xA33F128D,0xE327A977,
+ 0xC7E3F1F9,0x962FC963,0x3F2B3885,0x613716AF,
+ 0x781948B1,0x2B2E43DB,0xFCFCFCFD,0x6FD0EB67,
+ 0xFA3F47E9,0xD2FD2FD3,0x3F4FD3F5,0xD4E25B9F,
+ 0x5F02A3A1,0xBF5A814B,0x7C32B16D,0xD3431B57,
+ 0xD8FD8FD9,0x8D28AC43,0xDA6C0965,0xDB195E8F,
+ 0x0FDBC091,0x61F2A4BB,0xDCFDCFDD,0x46FDD947,
+ 0x56BE69C9,0xEB2FDEB3,0x26E978D5,0xEFDFBF7F,
+ 0x0FE03F81,0xC9484E2B,0xE133F84D,0xE1A8C537,
+ 0x077975B9,0x70586723,0xCD29C245,0xFAA11E6F,
+ 0x0FE3C071,0x08B51D9B,0x8CE2CABD,0xBF937F27,
+ 0xA8FE53A9,0x592FE593,0x2C0685B5,0x2EB11B5F,
+ 0xFCD1E361,0x451AB30B,0x72CFE72D,0xDB35A717,
+ 0xFB74A399,0xE80BFA03,0x0D516325,0x1BCB564F,
+ 0xE02E4851,0xD962AE7B,0x10F8ED9D,0x95AEDD07,
+ 0xE9DC0589,0xA18A4473,0xEA53FA95,0xEE936F3F,
+ 0x90948F41,0xEAFEAFEB,0x3D137E0D,0xEF46C0F7,
+ 0x028C1979,0x791064E3,0xC04FEC05,0xE115062F,
+ 0x32385831,0x6E68575B,0xA10D387D,0x6FECF2E7,
+ 0x3FB47F69,0xED4BFB53,0x74FED775,0xDB43BB1F,
+ 0x87654321,0x9BA144CB,0x478BBCED,0xBFB912D7,
+ 0x1FDCD759,0x14B2A7C3,0xCB125CE5,0x437B2E0F,
+ 0x10FEF011,0xD2B3183B,0x386CAB5D,0xEF6AC0C7,
+ 0x0E64C149,0x9A020A33,0xE6B41C55,0xFEFEFEFF
+};
+
+/*Computes (_a*_b-_c)/(2*_d+1) when the quotient is known to be exact.
+ _a, _b, _c, and _d may be arbitrary so long as the arbitrary precision result
+ fits in 32 bits, but currently the table for multiplicative inverses is only
+ valid for _d<128.*/
+static inline celt_uint32_t imusdiv32odd(celt_uint32_t _a,celt_uint32_t _b,
+ celt_uint32_t _c,celt_uint32_t _d){
+ return (_a*_b-_c)*INV_TABLE[_d]&MASK32;
+}
+
+/*Computes (_a*_b-_c)/_d when the quotient is known to be exact.
+ _d does not actually have to be even, but imusdiv32odd will be faster when
+ it's odd, so you should use that instead.
+ _a and _d are assumed to be small (e.g., _a*_d fits in 32 bits; currently the
+ table for multiplicative inverses is only valid for _d<256).
+ _b and _c may be arbitrary so long as the arbitrary precision result fits in
+ 32 bits.*/
+static inline celt_uint32_t imusdiv32even(celt_uint32_t _a,celt_uint32_t _b,
+ celt_uint32_t _c,celt_uint32_t _d){
+ unsigned inv;
+ int mask;
+ int shift;
+ int one;
+ shift=EC_ILOG(_d^_d-1);
+ inv=INV_TABLE[_d-1>>shift];
+ shift--;
+ one=1<<shift;
+ mask=one-1;
+ return (_a*(_b>>shift)-(_c>>shift)+
+ (_a*(_b&mask)+one-(_c&mask)>>shift)-1)*inv&MASK32;
+}
+
/*Computes the next row/column of any recurrence that obeys the relation
u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
_ui0 is the base case for the new row/column.*/
@@ -145,119 +214,127 @@
/*Returns the number of ways of choosing _m elements from a set of size _n with
replacement when a sign bit is needed for each unique element.
- On exit, _u will be initialized to column _m of U(n,m).*/
+ _u: On exit, _u[i] contains U(i+1,_m).*/
celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){
- int k;
- CELT_MEMSET(_u,0,_n);
- if(_m<=0)return 1;
- if(_n<=0)return 0;
- for(k=1;k<_m;k++)unext32(_u,_n,2);
- return ncwrs_unext32(_n,_u);
+ celt_uint32_t ret;
+ celt_uint32_t um2;
+ int k;
+ /*If _m==0, _u[] should be set to zero and the return should be 1.*/
+ celt_assert(_m>0);
+ /*We'll overflow our buffer unless _n>=2.*/
+ celt_assert(_n>=2);
+ um2=_u[0]=1;
+ if(_m<=6){
+ if(_m<2){
+ k=1;
+ do _u[k]=1;
+ while(++k<_n);
+ }
+ else{
+ k=1;
+ do _u[k]=(k<<1)+1;
+ while(++k<_n);
+ for(k=2;k<_m;k++)unext32(_u,_n,1);
+ }
+ }
+ else{
+ celt_uint32_t um1;
+ celt_uint32_t n2m1;
+ _u[1]=n2m1=um1=(_m<<1)-1;
+ for(k=2;k<_n;k++){
+ /*U(n,m) = ((2*n-1)*U(n,m-1)-U(n,m-2))/(m-1) + U(n,m-2)*/
+ _u[k]=um2=imusdiv32even(n2m1,um1,um2,k)+um2;
+ if(++k>=_n)break;
+ _u[k]=um1=imusdiv32odd(n2m1,um2,um1,k>>1)+um1;
+ }
+ }
+ ret=1;
+ k=1;
+ do ret+=_u[k];
+ while(++k<_n);
+ return ret<<1;
}
+
/*Returns the _i'th combination of _m elements chosen from a set of size _n
with associated sign bits.
- _x: Returns the combination with elements sorted in ascending order.
- _s: Returns the associated sign bits.
- _u: Temporary storage already initialized to column _m of U(n,m).
- Its contents will be overwritten.*/
-void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_x,int *_s,celt_uint32_t *_u){
- int j;
- int k;
- for(k=j=0;k<_m;k++){
- celt_uint32_t p;
- celt_uint32_t t;
- p=_u[_n-j-1];
- if(k>0){
- t=p>>1;
- if(t<=_i||_s[k-1])_i+=t;
+ _y: Returns the vector of pulses.
+ _u: Must contain entries [1..._n] of column _m of U() on input.
+ Its contents will be destructively modified.*/
+void cwrsi32(int _n,int _m,celt_uint32_t _i,celt_uint32_t _nc,int *_y,
+ celt_uint32_t *_u){
+ celt_uint32_t p;
+ celt_uint32_t q;
+ int j;
+ int k;
+ celt_assert(_n>0);
+ p=_nc;
+ q=0;
+ j=0;
+ k=_m;
+ do{
+ int s;
+ int yj;
+ p-=q;
+ q=_u[_n-j-1];
+ p-=q;
+ s=_i>=p;
+ if(s)_i-=p;
+ yj=k;
+ while(q>_i){
+ uprev32(_u,_n-j,--k>0);
+ p=q;
+ q=_u[_n-j-1];
}
- while(p<=_i){
- _i-=p;
- j++;
- p=_u[_n-j-1];
- }
- t=p>>1;
- _s[k]=_i>=t;
- _x[k]=j;
- if(_s[k])_i-=t;
- uprev32(_u,_n-j,2);
+ _i-=q;
+ yj-=k;
+ _y[j]=yj-(yj<<1&-s);
}
+ while(++j<_n);
}
/*Returns the index of the given combination of _m elements chosen from a set
of size _n with associated sign bits.
- _x: The combination with elements sorted in ascending order.
- _s: The associated sign bits.
- _u: Temporary storage already initialized to column _m of U(n,m).
- Its contents will be overwritten.*/
-celt_uint32_t icwrs32(int _n,int _m,const int *_x,const int *_s,
+ _y: The vector of pulses, whose sum of absolute values must be _m.
+ _nc: Returns V(_n,_m).*/
+celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y,
celt_uint32_t *_u){
+ celt_uint32_t nc;
celt_uint32_t i;
int j;
int k;
- i=0;
- for(k=j=0;k<_m;k++){
- celt_uint32_t p;
- p=_u[_n-j-1];
- if(k>0)p>>=1;
- while(j<_x[k]){
- i+=p;
- j++;
- p=_u[_n-j-1];
- }
- if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
- uprev32(_u,_n-j,2);
+ /*We can't unroll the first two iterations of the loop unless _n>=2.*/
+ celt_assert(_n>=2);
+ nc=1;
+ i=_y[_n-1]<0;
+ _u[0]=0;
+ for(k=1;k<=_m+1;k++)_u[k]=(k<<1)-1;
+ k=abs(_y[_n-1]);
+ j=_n-2;
+ nc+=_u[_m];
+ i+=_u[k];
+ k+=abs(_y[j]);
+ if(_y[j]<0)i+=_u[k+1];
+ while(j-->0){
+ unext32(_u,_m+2,0);
+ nc+=_u[_m];
+ i+=_u[k];
+ k+=abs(_y[j]);
+ if(_y[j]<0)i+=_u[k+1];
}
+ /*If _m==0, nc should not be doubled.*/
+ celt_assert(_m>0);
+ *_nc=nc<<1;
return i;
}
-/*Converts a combination _x of _m unit pulses with associated sign bits _s into
- a pulse vector _y of length _n.
- _y: Returns the vector of pulses.
- _x: The combination with elements sorted in ascending order. _x[_m] = -1
- _s: The associated sign bits.*/
-void comb2pulse(int _n,int _m,int * restrict _y,const int *_x,const int *_s){
- int k;
- const int signs[2]={1,-1};
- CELT_MEMSET(_y, 0, _n);
- k=0; do {
- _y[_x[k]]+=signs[_s[k]];
- } while (++k<_m);
-}
-
-/*Converts a pulse vector vector _y of length _n into a combination of _m unit
- pulses with associated sign bits _s.
- _x: Returns the combination with elements sorted in ascending order.
- _s: Returns the associated sign bits.
- _y: The vector of pulses, whose sum of absolute values must be _m.*/
-void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
- int j;
- int k;
- for(k=j=0;j<_n;j++){
- if(_y[j]){
- int n;
- int s;
- n=abs(_y[j]);
- s=_y[j]<0;
- do {
- _x[k]=j;
- _s[k]=s;
- k++;
- } while (--n>0);
- }
- }
-}
-
-static inline void encode_comb32(int _n,int _m,const int *_x,const int *_s,
- ec_enc *_enc){
+static inline void encode_pulse32(int _n,int _m,const int *_y,ec_enc *_enc){
VARDECL(celt_uint32_t,u);
celt_uint32_t nc;
celt_uint32_t i;
SAVE_STACK;
- ALLOC(u,_n,celt_uint32_t);
- nc=ncwrs_u32(_n,_m,u);
- i=icwrs32(_n,_m,_x,_s,u);
+ ALLOC(u,_m+2,celt_uint32_t);
+ i=icwrs32(_n,_m,&nc,_y,u);
ec_enc_uint(_enc,i,nc);
RESTORE_STACK;
}
@@ -283,14 +360,6 @@
void encode_pulses(int *_y, int N, int K, ec_enc *enc)
{
- VARDECL(int, comb);
- VARDECL(int, signs);
- SAVE_STACK;
-
- ALLOC(comb, K, int);
- ALLOC(signs, K, int);
-
- pulse2comb(N, K, comb, signs, _y);
if (K==0) {
} else if (N==1)
{
@@ -297,7 +366,7 @@
ec_enc_bits(enc, _y[0]<0, 1);
} else if(fits_in32(N,K))
{
- encode_comb32(N, K, comb, signs, enc);
+ encode_pulse32(N, K, _y, enc);
} else {
int i;
int count=0;
@@ -309,25 +378,20 @@
encode_pulses(_y, split, count, enc);
encode_pulses(_y+split, N-split, K-count, enc);
}
- RESTORE_STACK;
}
-static inline void decode_comb32(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
+static inline void decode_pulse32(int _n,int _m,int *_y,ec_dec *_dec){
VARDECL(celt_uint32_t,u);
+ celt_uint32_t nc;
SAVE_STACK;
ALLOC(u,_n,celt_uint32_t);
- cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_x,_s,u);
+ nc=ncwrs_u32(_n,_m,u);
+ cwrsi32(_n,_m,ec_dec_uint(_dec,nc),nc,_y,u);
RESTORE_STACK;
}
void decode_pulses(int *_y, int N, int K, ec_dec *dec)
{
- VARDECL(int, comb);
- VARDECL(int, signs);
- SAVE_STACK;
-
- ALLOC(comb, K, int);
- ALLOC(signs, K, int);
if (K==0) {
int i;
for (i=0;i<N;i++)
@@ -341,8 +405,7 @@
_y[0] = -K;
} else if(fits_in32(N,K))
{
- decode_comb32(N, K, comb, signs, dec);
- comb2pulse(N, K, _y, comb, signs);
+ decode_pulse32(N, K, _y, dec);
} else {
int split;
int count = ec_dec_uint(dec,K+1);
@@ -350,5 +413,4 @@
decode_pulses(_y, split, count, dec);
decode_pulses(_y+split, N-split, K-count, dec);
}
- RESTORE_STACK;
}
--- a/libcelt/cwrs.h
+++ b/libcelt/cwrs.h
@@ -46,27 +46,21 @@
/* 32-bit versions */
celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u);
-void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_x,int *_s,
+void cwrsi32(int _n,int _m,celt_uint32_t _i,celt_uint32_t _nc,int *_y,
celt_uint32_t *_u);
-celt_uint32_t icwrs32(int _n,int _m,const int *_x,const int *_s,
+celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y,
celt_uint32_t *_u);
/* 64-bit versions */
celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u);
-celt_uint64_t ncwrs_unext64(int _n,celt_uint64_t *_u);
-
-void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s,
+void cwrsi64(int _n,int _m,celt_uint64_t _i,celt_uint64_t _nc,int *_y,
celt_uint64_t *_u);
-celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s,
+celt_uint64_t icwrs64(int _n,int _m,celt_uint64_t *_nc,const int *_y,
celt_uint64_t *_u);
-
-void comb2pulse(int _n,int _m,int * restrict _y,const int *_x,const int *_s);
-
-void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y);
int get_required_bits(int N, int K, int frac);
--- a/tests/cwrs32-test.c
+++ b/tests/cwrs32-test.c
@@ -21,33 +21,24 @@
inc=nc/10000;
if(inc<1)inc=1;
for(i=0;i<nc;i+=inc){
- celt_uint32_t u[NMAX];
- int x[MMAX];
- int s[MMAX];
- int x2[MMAX];
- int s2[MMAX];
+ celt_uint32_t u[NMAX>MMAX+2?NMAX:MMAX+2];
int y[NMAX];
+ celt_uint32_t v;
int k;
memcpy(u,uu,n*sizeof(*u));
- cwrsi32(n,m,i,x,s,u);
- /*printf("%6u of %u:",i,nc);*/
- /*for(k=0;k<m;k++){
- printf(" %c%i",k>0&&x[k]==x[k-1]?' ':s[k]?'-':'+',x[k]);
- }
+ cwrsi32(n,m,i,nc,y,u);
+ /*printf("%6u of %u:",i,nc);
+ for(k=0;k<n;k++)printf(" %+3i",y[k]);
printf(" ->");*/
- memcpy(u,uu,n*sizeof(*u));
- if(icwrs32(n,m,x,s,u)!=i){
+ if(icwrs32(n,m,&v,y,u)!=i){
fprintf(stderr,"Combination-index mismatch.\n");
return 1;
}
- comb2pulse(n,m,y,x,s);
- /*for(j=0;j<n;j++)printf(" %c%i",y[j]?y[j]<0?'-':'+':' ',abs(y[j]));
- printf("\n");*/
- pulse2comb(n,m,x2,s2,y);
- for(k=0;k<m;k++)if(x[k]!=x2[k]||s[k]!=s2[k]){
- fprintf(stderr,"Pulse-combination mismatch.\n");
- return 1;
+ if(v!=nc){
+ fprintf(stderr,"Combination count mismatch.\n");
+ return 2;
}
+ /*printf(" %6u\n",i);*/
}
/*printf("\n");*/
}
--- a/tests/cwrs64-test.c
+++ b/tests/cwrs64-test.c
@@ -24,33 +24,24 @@
if(inc<1)inc=1;
/*printf("%d/%d: %llu",n,m, nc);*/
for(i=0;i<nc;i+=inc){
- celt_uint64_t u[NMAX];
- int x[MMAX];
- int s[MMAX];
- int x2[MMAX];
- int s2[MMAX];
+ celt_uint64_t u[NMAX>MMAX+2?NMAX:MMAX+2];
int y[NMAX];
+ celt_uint64_t v;
int k;
memcpy(u,uu,n*sizeof(*u));
- cwrsi64(n,m,i,x,s,u);
+ cwrsi64(n,m,i,nc,y,u);
/*printf("%llu of %llu:",i,nc);
- for(k=0;k<m;k++){
- printf(" %c%i",k>0&&x[k]==x[k-1]?' ':s[k]?'-':'+',x[k]);
- }
+ for(k=0;k<n;k++)printf(" %+3i",y[k]);
printf(" ->");*/
- memcpy(u,uu,n*sizeof(*u));
- if(icwrs64(n,m,x,s,u)!=i){
+ if(icwrs64(n,m,&v,y,u)!=i){
fprintf(stderr,"Combination-index mismatch.\n");
return 1;
}
- comb2pulse(n,m,y,x,s);
- /*for(j=0;j<n;j++)printf(" %c%i",y[j]?y[j]<0?'-':'+':' ',abs(y[j]));
- printf("\n");*/
- pulse2comb(n,m,x2,s2,y);
- for(k=0;k<m;k++)if(x[k]!=x2[k]||s[k]!=s2[k]){
- fprintf(stderr,"Pulse-combination mismatch.\n");
- return 1;
+ if(v!=nc){
+ fprintf(stderr,"Combination count mismatch.\n");
+ return 2;
}
+ /*printf(" %6llu\n",i);*/
}
/*printf("\n");*/
}