shithub: opus

Download patch

ref: 5ee9715c5cb9694f0af6aa580ec9855b052d6d9b
parent: d910274f790ba586c40a625bb9120d13554c3443
author: Timothy B.B Terriberry <[email protected]>
date: Sun Sep 21 11:10:58 EDT 2008

Change cwrsi() to operate on rows of U instead of columns.

It is no slower with a large number of pulses, and is as much as 30% faster
with a large number of dimensions.

--- a/libcelt/cwrs.c
+++ b/libcelt/cwrs.c
@@ -264,92 +264,55 @@
 
 /*Returns the number of ways of choosing _m elements from a set of size _n with
    replacement when a sign bit is needed for each unique element.
-  On input, _u should be initialized to column (_m-1) of U(n,m).
-  On exit, _u will be initialized to column _m of U(n,m).*/
-celt_uint32_t ncwrs_unext32(int _n,celt_uint32_t *_ui){
-  celt_uint32_t ret;
-  celt_uint32_t ui0;
-  celt_uint32_t ui1;
-  int           j;
-  ret=ui0=2;
-  celt_assert(_n>=2);
-  j=1; do {
-    ui1=_ui[j]+_ui[j-1]+ui0;
-    _ui[j-1]=ui0;
-    ui0=ui1;
-    ret+=ui0;
-  } while (++j<_n);
-  _ui[j-1]=ui0;
-  return ret;
-}
-
-celt_uint64_t ncwrs_unext64(int _n,celt_uint64_t *_ui){
-  celt_uint64_t ret;
-  celt_uint64_t ui0;
-  celt_uint64_t ui1;
-  int           j;
-  ret=ui0=1;
-  celt_assert(_n>=2);
-  j=1; do {
-    ui1=_ui[j]+_ui[j-1]+ui0;
-    _ui[j-1]=ui0;
-    ui0=ui1;
-    ret+=ui0;
-  } while (++j<_n);
-  _ui[j-1]=ui0;
-  return ret<<=1;
-}
-
-/*Returns the number of ways of choosing _m elements from a set of size _n with
-   replacement when a sign bit is needed for each unique element.
-  _u: On exit, _u[i] contains U(i+1,_m).*/
+  _u: On exit, _u[i] contains U(_n,i) for i in [0..._m+1].*/
 celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){
-  celt_uint32_t ret;
   celt_uint32_t um2;
   int           k;
-  /*If _m==0, _u[] should be set to zero and the return should be 1.*/
-  celt_assert(_m>0);
-  /*We'll overflow our buffer unless _n>=2.*/
-  celt_assert(_n>=2);
-  um2=_u[0]=1;
-  if(_m<=6){
-    if(_m<2){
-      k=1;
-      do _u[k]=1;
-      while(++k<_n);
-    }
-    else{
-      k=1;
-      do _u[k]=(k<<1)+1;
-      while(++k<_n);
-      for(k=2;k<_m;k++)unext32(_u,_n,1);
-    }
+  int           len;
+  len=_m+2;
+  _u[0]=0;
+  _u[1]=um2=1;
+  if(_n<=6){
+    /*If _n==0, _u[0] should be 1 and the rest should be 0.*/
+    /*If _n==1, _u[i] should be 1 for i>1.*/
+    celt_assert(_n>=2);
+    /*If _m==0, the following do-while loop will overflow the buffer.*/
+    celt_assert(_m>0);
+    k=2;
+    do _u[k]=(k<<1)-1;
+    while(++k<len);
+    for(k=2;k<_n;k++)unext32(_u,len,0);/*Full-row step, as in icwrs32().*/
   }
   else{
     celt_uint32_t um1;
     celt_uint32_t n2m1;
-    _u[1]=n2m1=um1=(_m<<1)-1;
-    for(k=2;k<_n;k++){
+    _u[2]=n2m1=um1=(_n<<1)-1;
+    for(k=3;k<len;k++){
       /*U(n,m) = ((2*n-1)*U(n,m-1)-U(n,m-2))/(m-1) + U(n,m-2)*/
-      _u[k]=um2=imusdiv32even(n2m1,um1,um2,k)+um2;
-      if(++k>=_n)break;
-      _u[k]=um1=imusdiv32odd(n2m1,um2,um1,k>>1)+um1;
+      _u[k]=um2=imusdiv32even(n2m1,um1,um2,k-1)+um2;
+      if(++k>=len)break;
+      _u[k]=um1=imusdiv32odd(n2m1,um2,um1,(k-1)>>1)+um1;
     }
   }
-  ret=1;
-  k=1;
-  do ret+=_u[k];
-  while(++k<_n);
-  return ret<<1;
+  return _u[_m]+_u[_m+1];
 }
 
 celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u){
   int k;
-  CELT_MEMSET(_u,0,_n);
-  if(_m<=0)return 1;
-  if(_n<=0)return 0;
-  for(k=1;k<_m;k++)unext64(_u,_n,1);
-  return ncwrs_unext64(_n,_u);
+  int len;
+  len=_m+2;
+  _u[0]=0;
+  /*If _n==0, _u[0] should be 1 and the rest should be 0.*/
+  /*If _n==1, _u[i] should be 1 for i>1.*/
+  celt_assert(_n>=2);
+  k=1;
+  do _u[k]=(k<<1)-1;
+  while(++k<len);
+  for(k=2;k<_n;k++)unext64(_u,len,0);/*Full-row step, as in icwrs64().*/
+  /*TODO: For large _n, an imusdiv64 could make this O(_m) instead of
+     O(_n*_m), but would require an INV_TABLE twice as large, as well as lots
+     of 64x64->64 bit multiplies.*/
+  return _u[_m]+_u[_m+1];
 }
 
 
@@ -356,36 +319,28 @@
 /*Returns the _i'th combination of _m elements chosen from a set of size _n
    with associated sign bits.
   _y: Returns the vector of pulses.
-  _u: Must contain entries [1..._n] of column _m of U() on input.
+  _u: On input, entries [0..._m+1] of row _n of U(), as from ncwrs_u32().
       Its contents will be destructively modified.*/
-void cwrsi32(int _n,int _m,celt_uint32_t _i,celt_uint32_t _nc,int *_y,
- celt_uint32_t *_u){
-  celt_uint32_t p;
-  celt_uint32_t q;
-  int           j;
-  int           k;
+void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_y,celt_uint32_t *_u){
+  int j;
+  int k;
   celt_assert(_n>0);
-  p=_nc;
-  q=0;
   j=0;
   k=_m;
   do{
-    int s;
-    int yj;
-    p-=q;
-    q=_u[_n-j-1];
-    p-=q;
+    celt_uint32_t p;
+    int           s;
+    int           yj;
+    p=_u[k+1];/*Sign bit: _i>=_u[k+1] makes _y[j] negative.*/
     s=_i>=p;
     if(s)_i-=p;
     yj=k;
-    while(q>_i){
-      uprev32(_u,_n-j,--k>0);
-      p=q;
-      q=_u[_n-j-1];
-    }
-    _i-=q;
+    p=_u[k];
+    while(p>_i)p=_u[--k];
+    _i-=p;
     yj-=k;
     _y[j]=yj-(yj<<1&-s);
+    uprev32(_u,k+2,0);
   }
   while(++j<_n);
 }
@@ -393,36 +348,28 @@
 /*Returns the _i'th combination of _m elements chosen from a set of size _n
    with associated sign bits.
   _y: Returns the vector of pulses.
-  _u: Must contain entries [1..._n] of column _m of U() on input.
+  _u: On input, entries [0..._m+1] of row _n of U(), as from ncwrs_u64().
       Its contents will be destructively modified.*/
-void cwrsi64(int _n,int _m,celt_uint64_t _i,celt_uint64_t _nc,int *_y,
- celt_uint64_t *_u){
-  celt_uint64_t p;
-  celt_uint64_t q;
-  int           j;
-  int           k;
+void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_y,celt_uint64_t *_u){
+  int j;
+  int k;
   celt_assert(_n>0);
-  p=_nc;
-  q=0;
   j=0;
   k=_m;
   do{
-    int s;
-    int yj;
-    p-=q;
-    q=_u[_n-j-1];
-    p-=q;
+    celt_uint64_t p;
+    int           s;
+    int           yj;
+    p=_u[k+1];/*Sign bit: _i>=_u[k+1] makes _y[j] negative.*/
     s=_i>=p;
     if(s)_i-=p;
     yj=k;
-    while(q>_i){
-      uprev64(_u,_n-j,--k>0);
-      p=q;
-      q=_u[_n-j-1];
-    }
-    _i-=q;
+    p=_u[k];
+    while(p>_i)p=_u[--k];
+    _i-=p;
     yj-=k;
     _y[j]=yj-(yj<<1&-s);
+    uprev64(_u,k+2,0);
   }
   while(++j<_n);
 }
@@ -433,32 +380,26 @@
   _nc: Returns V(_n,_m).*/
 celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y,
  celt_uint32_t *_u){
-  celt_uint32_t nc;
   celt_uint32_t i;
   int           j;
   int           k;
   /*We can't unroll the first two iterations of the loop unless _n>=2.*/
   celt_assert(_n>=2);
-  nc=1;
   i=_y[_n-1]<0;
   _u[0]=0;
   for(k=1;k<=_m+1;k++)_u[k]=(k<<1)-1;
   k=abs(_y[_n-1]);
   j=_n-2;
-  nc+=_u[_m];
   i+=_u[k];
   k+=abs(_y[j]);
   if(_y[j]<0)i+=_u[k+1];
   while(j-->0){
     unext32(_u,_m+2,0);
-    nc+=_u[_m];
     i+=_u[k];
     k+=abs(_y[j]);
     if(_y[j]<0)i+=_u[k+1];
   }
-  /*If _m==0, nc should not be doubled.*/
-  celt_assert(_m>0);
-  *_nc=nc<<1;
+  *_nc=_u[_m]+_u[_m+1];/*Same expression ncwrs_u32() returns.*/
   return i;
 }
 
@@ -468,32 +409,26 @@
   _nc: Returns V(_n,_m).*/
 celt_uint64_t icwrs64(int _n,int _m,celt_uint64_t *_nc,const int *_y,
  celt_uint64_t *_u){
-  celt_uint64_t nc;
   celt_uint64_t i;
   int           j;
   int           k;
   /*We can't unroll the first two iterations of the loop unless _n>=2.*/
   celt_assert(_n>=2);
-  nc=1;
   i=_y[_n-1]<0;
   _u[0]=0;
   for(k=1;k<=_m+1;k++)_u[k]=(k<<1)-1;
   k=abs(_y[_n-1]);
   j=_n-2;
-  nc+=_u[_m];
   i+=_u[k];
   k+=abs(_y[j]);
   if(_y[j]<0)i+=_u[k+1];
   while(j-->0){
     unext64(_u,_m+2,0);
-    nc+=_u[_m];
     i+=_u[k];
     k+=abs(_y[j]);
     if(_y[j]<0)i+=_u[k+1];
   }
-  /*If _m==0, nc should not be doubled.*/
-  celt_assert(_m>0);
-  *_nc=nc<<1;
+  *_nc=_u[_m]+_u[_m+1];/*Same expression ncwrs_u64() returns.*/
   return i;
 }
 
@@ -526,7 +461,7 @@
    {
       VARDECL(celt_uint64_t,u);
       SAVE_STACK;
-      ALLOC(u,N,celt_uint64_t);
+      ALLOC(u,K+2,celt_uint64_t);
       nbits = log2_frac64(ncwrs_u64(N,K,u), frac);
       RESTORE_STACK;
    } else {
@@ -564,21 +499,17 @@
 
 static inline void decode_pulse32(int _n,int _m,int *_y,ec_dec *_dec){
   VARDECL(celt_uint32_t,u);
-  celt_uint32_t nc;
   SAVE_STACK;
-  ALLOC(u,_n,celt_uint32_t);
-  nc=ncwrs_u32(_n,_m,u);
-  cwrsi32(_n,_m,ec_dec_uint(_dec,nc),nc,_y,u);
+  ALLOC(u,_m+2,celt_uint32_t);/*Holds U(_n,i) for i in [0..._m+1].*/
+  cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_y,u);
   RESTORE_STACK;
 }
 
 static inline void decode_pulse64(int _n,int _m,int *_y,ec_dec *_dec){
   VARDECL(celt_uint64_t,u);
-  celt_uint64_t nc;
   SAVE_STACK;
-  ALLOC(u,_n,celt_uint64_t);
-  nc=ncwrs_u64(_n,_m,u);
-  cwrsi64(_n,_m,ec_dec_uint64(_dec,nc),nc,_y,u);
+  ALLOC(u,_m+2,celt_uint64_t);/*Holds U(_n,i) for i in [0..._m+1].*/
+  cwrsi64(_n,_m,ec_dec_uint64(_dec,ncwrs_u64(_n,_m,u)),_y,u);
   RESTORE_STACK;
 }
 
--- a/libcelt/cwrs.h
+++ b/libcelt/cwrs.h
@@ -46,8 +46,7 @@
 /* 32-bit versions */
 celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u);
 
-void cwrsi32(int _n,int _m,celt_uint32_t _i,celt_uint32_t _nc,int *_y,
- celt_uint32_t *_u);
+void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_y,celt_uint32_t *_u);
 
 celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y,
  celt_uint32_t *_u);
@@ -55,8 +54,7 @@
 /* 64-bit versions */
 celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u);
 
-void cwrsi64(int _n,int _m,celt_uint64_t _i,celt_uint64_t _nc,int *_y,
- celt_uint64_t *_u);
+void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_y,celt_uint64_t *_u);
 
 celt_uint64_t icwrs64(int _n,int _m,celt_uint64_t *_nc,const int *_y,
  celt_uint64_t *_u);
--- a/tests/cwrs32-test.c
+++ b/tests/cwrs32-test.c
@@ -13,7 +13,7 @@
   for(n=2;n<=NMAX;n++){
     int m;
     for(m=1;m<=MMAX;m++){
-      celt_uint32_t uu[NMAX];
+      celt_uint32_t uu[MMAX+2];
       celt_uint32_t inc;
       celt_uint32_t nc;
       celt_uint32_t i;
@@ -21,12 +21,12 @@
       inc=nc/10000;
       if(inc<1)inc=1;
       for(i=0;i<nc;i+=inc){
-        celt_uint32_t u[NMAX>MMAX+2?NMAX:MMAX+2];
+        celt_uint32_t u[MMAX+2];
         int           y[NMAX];
         celt_uint32_t v;
         int           k;
-        memcpy(u,uu,n*sizeof(*u));
-        cwrsi32(n,m,i,nc,y,u);
+        memcpy(u,uu,(m+2)*sizeof(*u));
+        cwrsi32(n,m,i,y,u);
         /*printf("%6u of %u:",i,nc);
         for(k=0;k<n;k++)printf(" %+3i",y[k]);
         printf(" ->");*/
--- a/tests/cwrs64-test.c
+++ b/tests/cwrs64-test.c
@@ -14,7 +14,7 @@
   for(n=2;n<=NMAX;n+=3){
     int m;
     for(m=1;m<=MMAX;m++){
-      celt_uint64_t uu[NMAX];
+      celt_uint64_t uu[MMAX+2];
       celt_uint64_t inc;
       celt_uint64_t nc;
       celt_uint64_t i;
@@ -24,12 +24,12 @@
       if(inc<1)inc=1;
       /*printf("%d/%d: %llu",n,m, nc);*/
       for(i=0;i<nc;i+=inc){
-        celt_uint64_t u[NMAX>MMAX+2?NMAX:MMAX+2];
+        celt_uint64_t u[MMAX+2];
         int           y[NMAX];
         celt_uint64_t v;
         int           k;
-        memcpy(u,uu,n*sizeof(*u));
-        cwrsi64(n,m,i,nc,y,u);
+        memcpy(u,uu,(m+2)*sizeof(*u));
+        cwrsi64(n,m,i,y,u);
         /*printf("%llu of %llu:",i,nc);
         for(k=0;k<n;k++)printf(" %+3i",y[k]);
         printf(" ->");*/