shithub: opus

Download patch

ref: abf5c8ed64b7e59a434ed0f35babe61579242d20
parent: 1dab60cc91339f977b39db4849417483d5b3c15b
parent: d910274f790ba586c40a625bb9120d13554c3443
author: Jean-Marc Valin <[email protected]>
date: Fri Sep 19 04:02:50 EDT 2008

Merge branch 'cwrs_speedup' (derf's cwrs changes)

Conflicts:
	libcelt/cwrs.c

--- a/libcelt/cwrs.c
+++ b/libcelt/cwrs.c
@@ -92,6 +92,75 @@
    }   
 }
 
+#define MASK32 (0xFFFFFFFF)
+
+/*INV_TABLE[i] holds the multiplicative inverse of (2*i+1) mod 2**32.*/
+static const unsigned INV_TABLE[128]={
+  0x00000001,0xAAAAAAAB,0xCCCCCCCD,0xB6DB6DB7,
+  0x38E38E39,0xBA2E8BA3,0xC4EC4EC5,0xEEEEEEEF,
+  0xF0F0F0F1,0x286BCA1B,0x3CF3CF3D,0xE9BD37A7,
+  0xC28F5C29,0x684BDA13,0x4F72C235,0xBDEF7BDF,
+  0x3E0F83E1,0x8AF8AF8B,0x914C1BAD,0x96F96F97,
+  0xC18F9C19,0x2FA0BE83,0xA4FA4FA5,0x677D46CF,
+  0x1A1F58D1,0xFAFAFAFB,0x8C13521D,0x586FB587,
+  0xB823EE09,0xA08AD8F3,0xC10C9715,0xBEFBEFBF,
+  0xC0FC0FC1,0x07A44C6B,0xA33F128D,0xE327A977,
+  0xC7E3F1F9,0x962FC963,0x3F2B3885,0x613716AF,
+  0x781948B1,0x2B2E43DB,0xFCFCFCFD,0x6FD0EB67,
+  0xFA3F47E9,0xD2FD2FD3,0x3F4FD3F5,0xD4E25B9F,
+  0x5F02A3A1,0xBF5A814B,0x7C32B16D,0xD3431B57,
+  0xD8FD8FD9,0x8D28AC43,0xDA6C0965,0xDB195E8F,
+  0x0FDBC091,0x61F2A4BB,0xDCFDCFDD,0x46FDD947,
+  0x56BE69C9,0xEB2FDEB3,0x26E978D5,0xEFDFBF7F,
+  0x0FE03F81,0xC9484E2B,0xE133F84D,0xE1A8C537,
+  0x077975B9,0x70586723,0xCD29C245,0xFAA11E6F,
+  0x0FE3C071,0x08B51D9B,0x8CE2CABD,0xBF937F27,
+  0xA8FE53A9,0x592FE593,0x2C0685B5,0x2EB11B5F,
+  0xFCD1E361,0x451AB30B,0x72CFE72D,0xDB35A717,
+  0xFB74A399,0xE80BFA03,0x0D516325,0x1BCB564F,
+  0xE02E4851,0xD962AE7B,0x10F8ED9D,0x95AEDD07,
+  0xE9DC0589,0xA18A4473,0xEA53FA95,0xEE936F3F,
+  0x90948F41,0xEAFEAFEB,0x3D137E0D,0xEF46C0F7,
+  0x028C1979,0x791064E3,0xC04FEC05,0xE115062F,
+  0x32385831,0x6E68575B,0xA10D387D,0x6FECF2E7,
+  0x3FB47F69,0xED4BFB53,0x74FED775,0xDB43BB1F,
+  0x87654321,0x9BA144CB,0x478BBCED,0xBFB912D7,
+  0x1FDCD759,0x14B2A7C3,0xCB125CE5,0x437B2E0F,
+  0x10FEF011,0xD2B3183B,0x386CAB5D,0xEF6AC0C7,
+  0x0E64C149,0x9A020A33,0xE6B41C55,0xFEFEFEFF
+};
+
+/*Computes (_a*_b-_c)/(2*_d+1) when the quotient is known to be exact, by
+   multiplying by the inverse of 2*_d+1 modulo 2**32 (INV_TABLE[_d]).
+  _a, _b, and _c may be arbitrary so long as the arbitrary precision result
+   fits in 32 bits; _d must be <128, the current size of the inverse table.*/
+static inline celt_uint32_t imusdiv32odd(celt_uint32_t _a,celt_uint32_t _b,
+ celt_uint32_t _c,celt_uint32_t _d){
+  return (_a*_b-_c)*INV_TABLE[_d]&MASK32;
+}
+
+/*Computes (_a*_b-_c)/_d when the quotient is known to be exact.
+  The power-of-two part of _d is removed with shifts and the odd part via its
+   inverse in INV_TABLE; _d need not be even, but prefer imusdiv32odd if odd.
+  _a and _d are assumed to be small (e.g., _a*_d fits in 32 bits; currently the
+   table for multiplicative inverses is only valid for _d<256).
+  _b and _c may be arbitrary so long as the arbitrary precision result fits in
+   32 bits.*/
+static inline celt_uint32_t imusdiv32even(celt_uint32_t _a,celt_uint32_t _b,
+ celt_uint32_t _c,celt_uint32_t _d){
+  unsigned inv;
+  int      mask;
+  int      shift;
+  int      one;
+  shift=EC_ILOG(_d^_d-1); /*Parses as _d^(_d-1): ones through _d's lowest set bit.*/
+  inv=INV_TABLE[_d-1>>shift];
+  shift--;
+  one=1<<shift;
+  mask=one-1;
+  return (_a*(_b>>shift)-(_c>>shift)+
+   (_a*(_b&mask)+one-(_c&mask)>>shift)-1)*inv&MASK32; /*+one/-1 keeps the low-bits term >=1, avoiding unsigned underflow.*/
+}
+
 /*Computes the next row/column of any recurrence that obeys the relation
    u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
   _ui0 is the base case for the new row/column.*/
@@ -145,119 +214,127 @@
 
 /*Returns the number of ways of choosing _m elements from a set of size _n with
    replacement when a sign bit is needed for each unique element.
-  On exit, _u will be initialized to column _m of U(n,m).*/
+  _u: On exit, _u[i] contains U(i+1,_m).*/
 celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){
-  int k;
-  CELT_MEMSET(_u,0,_n);
-  if(_m<=0)return 1;
-  if(_n<=0)return 0;
-  for(k=1;k<_m;k++)unext32(_u,_n,2);
-  return ncwrs_unext32(_n,_u);
+  celt_uint32_t ret;
+  celt_uint32_t um2;
+  int           k;
+  /*If _m==0, _u[] should be set to zero and the return should be 1.*/
+  celt_assert(_m>0);
+  /*We'll overflow our buffer unless _n>=2.*/
+  celt_assert(_n>=2);
+  um2=_u[0]=1;
+  if(_m<=6){
+    if(_m<2){
+      k=1;
+      do _u[k]=1;
+      while(++k<_n);
+    }
+    else{
+      k=1;
+      do _u[k]=(k<<1)+1;
+      while(++k<_n);
+      for(k=2;k<_m;k++)unext32(_u,_n,1);
+    }
+  }
+  else{
+    celt_uint32_t um1;
+    celt_uint32_t n2m1;
+    _u[1]=n2m1=um1=(_m<<1)-1;
+    for(k=2;k<_n;k++){
+      /*U(n,m) = ((2*n-1)*U(n,m-1)-U(n,m-2))/(m-1) + U(n,m-2)*/
+      _u[k]=um2=imusdiv32even(n2m1,um1,um2,k)+um2;
+      if(++k>=_n)break;
+      _u[k]=um1=imusdiv32odd(n2m1,um2,um1,k>>1)+um1;
+    }
+  }
+  ret=1;
+  k=1;
+  do ret+=_u[k];
+  while(++k<_n);
+  return ret<<1;
 }
 
+
 /*Returns the _i'th combination of _m elements chosen from a set of size _n
    with associated sign bits.
-  _x: Returns the combination with elements sorted in ascending order.
-  _s: Returns the associated sign bits.
-  _u: Temporary storage already initialized to column _m of U(n,m).
-      Its contents will be overwritten.*/
-void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_x,int *_s,celt_uint32_t *_u){
-  int j;
-  int k;
-  for(k=j=0;k<_m;k++){
-    celt_uint32_t p;
-    celt_uint32_t t;
-    p=_u[_n-j-1];
-    if(k>0){
-      t=p>>1;
-      if(t<=_i||_s[k-1])_i+=t;
+  _y: Returns the vector of pulses.
+  _u: Must contain entries [1..._n] of column _m of U() on input.
+      Its contents will be destructively modified.*/
+void cwrsi32(int _n,int _m,celt_uint32_t _i,celt_uint32_t _nc,int *_y,
+ celt_uint32_t *_u){
+  celt_uint32_t p;
+  celt_uint32_t q;
+  int           j;
+  int           k;
+  celt_assert(_n>0);
+  p=_nc;
+  q=0;
+  j=0;
+  k=_m;
+  do{
+    int s;
+    int yj;
+    p-=q;
+    q=_u[_n-j-1];
+    p-=q;
+    s=_i>=p;
+    if(s)_i-=p;
+    yj=k;
+    while(q>_i){
+      uprev32(_u,_n-j,--k>0);
+      p=q;
+      q=_u[_n-j-1];
     }
-    while(p<=_i){
-      _i-=p;
-      j++;
-      p=_u[_n-j-1];
-    }
-    t=p>>1;
-    _s[k]=_i>=t;
-    _x[k]=j;
-    if(_s[k])_i-=t;
-    uprev32(_u,_n-j,2);
+    _i-=q;
+    yj-=k;
+    _y[j]=yj-(yj<<1&-s);
   }
+  while(++j<_n);
 }
 
 /*Returns the index of the given combination of _m elements chosen from a set
    of size _n with associated sign bits.
-  _x: The combination with elements sorted in ascending order.
-  _s: The associated sign bits.
-  _u: Temporary storage already initialized to column _m of U(n,m).
-      Its contents will be overwritten.*/
-celt_uint32_t icwrs32(int _n,int _m,const int *_x,const int *_s,
+  _y:  The vector of pulses, whose sum of absolute values must be _m.
+  _nc: Returns V(_n,_m).*/
+celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y,
  celt_uint32_t *_u){
+  celt_uint32_t nc;
   celt_uint32_t i;
   int           j;
   int           k;
-  i=0;
-  for(k=j=0;k<_m;k++){
-    celt_uint32_t p;
-    p=_u[_n-j-1];
-    if(k>0)p>>=1;
-    while(j<_x[k]){
-      i+=p;
-      j++;
-      p=_u[_n-j-1];
-    }
-    if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
-    uprev32(_u,_n-j,2);
+  /*We can't unroll the first two iterations of the loop unless _n>=2.*/
+  celt_assert(_n>=2);
+  nc=1;
+  i=_y[_n-1]<0;
+  _u[0]=0;
+  for(k=1;k<=_m+1;k++)_u[k]=(k<<1)-1;
+  k=abs(_y[_n-1]);
+  j=_n-2;
+  nc+=_u[_m];
+  i+=_u[k];
+  k+=abs(_y[j]);
+  if(_y[j]<0)i+=_u[k+1];
+  while(j-->0){
+    unext32(_u,_m+2,0);
+    nc+=_u[_m];
+    i+=_u[k];
+    k+=abs(_y[j]);
+    if(_y[j]<0)i+=_u[k+1];
   }
+  /*If _m==0, nc should not be doubled.*/
+  celt_assert(_m>0);
+  *_nc=nc<<1;
   return i;
 }
 
-/*Converts a combination _x of _m unit pulses with associated sign bits _s into
-   a pulse vector _y of length _n.
-  _y: Returns the vector of pulses.
-  _x: The combination with elements sorted in ascending order. _x[_m] = -1
-  _s: The associated sign bits.*/
-void comb2pulse(int _n,int _m,int * restrict _y,const int *_x,const int *_s){
-  int k;
-  const int signs[2]={1,-1};
-  CELT_MEMSET(_y, 0, _n);
-  k=0; do {
-    _y[_x[k]]+=signs[_s[k]];
-  } while (++k<_m);
-}
-
-/*Converts a pulse vector vector _y of length _n into a combination of _m unit
-   pulses with associated sign bits _s.
-  _x: Returns the combination with elements sorted in ascending order.
-  _s: Returns the associated sign bits.
-  _y: The vector of pulses, whose sum of absolute values must be _m.*/
-void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
-  int j;
-  int k;
-  for(k=j=0;j<_n;j++){
-    if(_y[j]){
-      int n;
-      int s;
-      n=abs(_y[j]);
-      s=_y[j]<0;
-      do {
-        _x[k]=j;
-        _s[k]=s;
-        k++;
-      } while (--n>0);
-    }
-  }
-}
-
-static inline void encode_comb32(int _n,int _m,const int *_x,const int *_s,
- ec_enc *_enc){
+static inline void encode_pulse32(int _n,int _m,const int *_y,ec_enc *_enc){
   VARDECL(celt_uint32_t,u);
   celt_uint32_t nc;
   celt_uint32_t i;
   SAVE_STACK;
-  ALLOC(u,_n,celt_uint32_t);
-  nc=ncwrs_u32(_n,_m,u);
-  i=icwrs32(_n,_m,_x,_s,u);
+  ALLOC(u,_m+2,celt_uint32_t);
+  i=icwrs32(_n,_m,&nc,_y,u);
   ec_enc_uint(_enc,i,nc);
   RESTORE_STACK;
 }
@@ -283,14 +360,6 @@
 
 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
 {
-   VARDECL(int, comb);
-   VARDECL(int, signs);
-   SAVE_STACK;
-
-   ALLOC(comb, K, int);
-   ALLOC(signs, K, int);
-
-   pulse2comb(N, K, comb, signs, _y);
    if (K==0) {
    } else if (N==1)
    {
@@ -297,7 +366,7 @@
       ec_enc_bits(enc, _y[0]<0, 1);
    } else if(fits_in32(N,K))
    {
-      encode_comb32(N, K, comb, signs, enc);
+      encode_pulse32(N, K, _y, enc);
    } else {
      int i;
      int count=0;
@@ -309,25 +378,20 @@
      encode_pulses(_y, split, count, enc);
      encode_pulses(_y+split, N-split, K-count, enc);
    }
-   RESTORE_STACK;
 }
 
-static inline void decode_comb32(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
+static inline void decode_pulse32(int _n,int _m,int *_y,ec_dec *_dec){
   VARDECL(celt_uint32_t,u);
+  celt_uint32_t nc;
   SAVE_STACK;
   ALLOC(u,_n,celt_uint32_t);
-  cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_x,_s,u);
+  nc=ncwrs_u32(_n,_m,u);
+  cwrsi32(_n,_m,ec_dec_uint(_dec,nc),nc,_y,u);
   RESTORE_STACK;
 }
 
 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
 {
-   VARDECL(int, comb);
-   VARDECL(int, signs);
-   SAVE_STACK;
-
-   ALLOC(comb, K, int);
-   ALLOC(signs, K, int);
    if (K==0) {
       int i;
       for (i=0;i<N;i++)
@@ -341,8 +405,7 @@
          _y[0] = -K;
    } else if(fits_in32(N,K))
    {
-      decode_comb32(N, K, comb, signs, dec);
-      comb2pulse(N, K, _y, comb, signs);
+      decode_pulse32(N, K, _y, dec);
    } else {
      int split;
      int count = ec_dec_uint(dec,K+1);
@@ -350,5 +413,4 @@
      decode_pulses(_y, split, count, dec);
      decode_pulses(_y+split, N-split, K-count, dec);
    }
-   RESTORE_STACK;
 }
--- a/libcelt/cwrs.h
+++ b/libcelt/cwrs.h
@@ -46,27 +46,21 @@
 /* 32-bit versions */
 celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u);
 
-void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_x,int *_s,
+void cwrsi32(int _n,int _m,celt_uint32_t _i,celt_uint32_t _nc,int *_y,
  celt_uint32_t *_u);
 
-celt_uint32_t icwrs32(int _n,int _m,const int *_x,const int *_s,
+celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y,
  celt_uint32_t *_u);
 
 /* 64-bit versions */
 celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u);
 
-celt_uint64_t ncwrs_unext64(int _n,celt_uint64_t *_u);
-
-void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s,
+void cwrsi64(int _n,int _m,celt_uint64_t _i,celt_uint64_t _nc,int *_y,
  celt_uint64_t *_u);
 
-celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s,
+celt_uint64_t icwrs64(int _n,int _m,celt_uint64_t *_nc,const int *_y,
  celt_uint64_t *_u);
 
-
-void comb2pulse(int _n,int _m,int * restrict _y,const int *_x,const int *_s);
-
-void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y);
 
 int get_required_bits(int N, int K, int frac);
 
--- a/tests/cwrs32-test.c
+++ b/tests/cwrs32-test.c
@@ -21,33 +21,24 @@
       inc=nc/10000;
       if(inc<1)inc=1;
       for(i=0;i<nc;i+=inc){
-        celt_uint32_t u[NMAX];
-        int           x[MMAX];
-        int           s[MMAX];
-        int           x2[MMAX];
-        int           s2[MMAX];
+        celt_uint32_t u[NMAX>MMAX+2?NMAX:MMAX+2];
         int           y[NMAX];
+        celt_uint32_t v;
         int           k;
         memcpy(u,uu,n*sizeof(*u));
-        cwrsi32(n,m,i,x,s,u);
-        /*printf("%6u of %u:",i,nc);*/
-        /*for(k=0;k<m;k++){
-          printf(" %c%i",k>0&&x[k]==x[k-1]?' ':s[k]?'-':'+',x[k]);
-        }
+        cwrsi32(n,m,i,nc,y,u);
+        /*printf("%6u of %u:",i,nc);
+        for(k=0;k<n;k++)printf(" %+3i",y[k]);
         printf(" ->");*/
-        memcpy(u,uu,n*sizeof(*u));
-        if(icwrs32(n,m,x,s,u)!=i){
+        if(icwrs32(n,m,&v,y,u)!=i){
           fprintf(stderr,"Combination-index mismatch.\n");
           return 1;
         }
-        comb2pulse(n,m,y,x,s);
-        /*for(j=0;j<n;j++)printf(" %c%i",y[j]?y[j]<0?'-':'+':' ',abs(y[j]));
-        printf("\n");*/
-        pulse2comb(n,m,x2,s2,y);
-        for(k=0;k<m;k++)if(x[k]!=x2[k]||s[k]!=s2[k]){
-          fprintf(stderr,"Pulse-combination mismatch.\n");
-          return 1;
+        if(v!=nc){
+          fprintf(stderr,"Combination count mismatch.\n");
+          return 2;
         }
+        /*printf(" %6u\n",i);*/
       }
       /*printf("\n");*/
     }
--- a/tests/cwrs64-test.c
+++ b/tests/cwrs64-test.c
@@ -24,33 +24,24 @@
       if(inc<1)inc=1;
       /*printf("%d/%d: %llu",n,m, nc);*/
       for(i=0;i<nc;i+=inc){
-        celt_uint64_t u[NMAX];
-        int           x[MMAX];
-        int           s[MMAX];
-        int           x2[MMAX];
-        int           s2[MMAX];
+        celt_uint64_t u[NMAX>MMAX+2?NMAX:MMAX+2];
         int           y[NMAX];
+        celt_uint64_t v;
         int           k;
         memcpy(u,uu,n*sizeof(*u));
-        cwrsi64(n,m,i,x,s,u);
+        cwrsi64(n,m,i,nc,y,u);
         /*printf("%llu of %llu:",i,nc);
-        for(k=0;k<m;k++){
-          printf(" %c%i",k>0&&x[k]==x[k-1]?' ':s[k]?'-':'+',x[k]);
-        }
+        for(k=0;k<n;k++)printf(" %+3i",y[k]);
         printf(" ->");*/
-        memcpy(u,uu,n*sizeof(*u));
-        if(icwrs64(n,m,x,s,u)!=i){
+        if(icwrs64(n,m,&v,y,u)!=i){
           fprintf(stderr,"Combination-index mismatch.\n");
           return 1;
         }
-        comb2pulse(n,m,y,x,s);
-        /*for(j=0;j<n;j++)printf(" %c%i",y[j]?y[j]<0?'-':'+':' ',abs(y[j]));
-        printf("\n");*/
-        pulse2comb(n,m,x2,s2,y);
-        for(k=0;k<m;k++)if(x[k]!=x2[k]||s[k]!=s2[k]){
-          fprintf(stderr,"Pulse-combination mismatch.\n");
-          return 1;
+        if(v!=nc){
+          fprintf(stderr,"Combination count mismatch.\n");
+          return 2;
         }
+        /*printf(" %6llu\n",i);*/
       }
       /*printf("\n");*/
     }