ref: 44c63350d1916e4f2b3d61304aafbd60519d7d7e
parent: 059749355d990ff3cc02ddcb1da9be146939adaa
author: Jean-Marc Valin <[email protected]>
date: Tue Mar 25 17:28:40 EDT 2008
optimisations: Another bunch of simplifications to alg_quant(), mainly to remove unnecessary copying and some conditional branches.
--- a/libcelt/mathops.h
+++ b/libcelt/mathops.h
@@ -54,6 +54,23 @@
}
#endif
+#ifndef OVERRIDE_FIND_MAX32
+/** Returns the index of the largest value in x[0..len-1].
+ *
+ * Only a strictly greater value replaces the current best, so on ties the
+ * lowest index wins. Defining OVERRIDE_FIND_MAX32 before this point
+ * suppresses this generic version.
+ *
+ * @param x   Array of 32-bit values to scan (read only)
+ * @param len Number of entries in x
+ * @return Index of the first maximum; 0 when len <= 0
+ */
+static inline int find_max32(celt_word32_t *x, int len)
+{
+   celt_word32_t max_corr;
+   int i, id = 0;
+   /* Nothing to scan: keep returning index 0 as before */
+   if (len <= 0)
+      return 0;
+   /* Seed the running maximum with the first entry instead of a sentinel.
+      The previous -VERY_LARGE16 seed is a 16-bit bound applied to 32-bit
+      data: entries at or below it were never selected, so index 0 came
+      back even when a later entry was the real maximum. */
+   max_corr = x[0];
+   for (i=1;i<len;i++)
+   {
+      if (x[i] > max_corr)
+      {
+         id = i;
+         max_corr = x[i];
+      }
+   }
+   return id;
+}
+#endif
+
#ifndef FIXED_POINT
--- a/libcelt/vq.c
+++ b/libcelt/vq.c
@@ -93,30 +93,19 @@
RESTORE_STACK;
}
-/** All the info necessary to keep track of a hypothesis during the search */
-struct NBest {
- celt_word32_t score;
- int sign;
- int pos;
- celt_word32_t xy;
- celt_word32_t yy;
- celt_word32_t yp;
-};
void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, ec_enc *enc)
{
- VARDECL(celt_norm_t, _y);
- VARDECL(celt_norm_t, _ny);
- VARDECL(int, _iy);
- VARDECL(int, _iny);
+ VARDECL(celt_norm_t, y);
+ VARDECL(int, iy);
VARDECL(int, signx);
- celt_norm_t *y, *ny;
- int *iy, *iny;
- int i, j;
+ VARDECL(celt_word32_t, scores);
+ int i, j, is;
+ celt_word16_t s;
int pulsesLeft;
+ celt_word32_t sum;
celt_word32_t xy, yy, yp;
- struct NBest nbest;
- celt_word32_t Rpp=0, Rxp=0;
+ celt_word16_t Rpp;
#ifdef FIXED_POINT
int yshift;
#endif
@@ -126,17 +115,11 @@
yshift = 14-EC_ILOG(K);
#endif
- ALLOC(_y, N, celt_norm_t);
- ALLOC(_ny, N, celt_norm_t);
- ALLOC(_iy, N, int);
- ALLOC(_iny, N, int);
+ ALLOC(y, N, celt_norm_t);
+ ALLOC(iy, N, int);
ALLOC(signx, N, int);
+ ALLOC(scores, N, celt_word32_t);
- y = _y;
- ny = _ny;
- iy = _iy;
- iny = _iny;
-
for (j=0;j<N;j++)
{
if (X[j]>0)
@@ -145,13 +128,12 @@
signx[j]=-1;
}
+ sum = 0;
for (j=0;j<N;j++)
{
- Rpp = MAC16_16(Rpp, P[j],P[j]);
- Rxp = MAC16_16(Rxp, X[j],P[j]);
+ sum = MAC16_16(sum, P[j],P[j]);
}
- Rpp = ROUND16(Rpp, NORM_SHIFT);
- Rxp = ROUND16(Rxp, NORM_SHIFT);
+ Rpp = ROUND16(sum, NORM_SHIFT);
celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");
@@ -165,6 +147,9 @@
while (pulsesLeft > 0)
{
int pulsesAtOnce=1;
+ int sign;
+ celt_word32_t Rxy, Ryy, Ryp;
+ celt_word32_t g;
/* Decide on how many pulses to find at once */
pulsesAtOnce = pulsesLeft/N;
@@ -172,31 +157,31 @@
pulsesAtOnce = 1;
/*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
- nbest.score = -VERY_LARGE32;
-
- for (j=0;j<N;j++)
+ /* Choose between fast and accurate strategy depending on where we are in the search */
+ if (pulsesLeft>1)
{
- int sign;
- /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
- celt_word32_t Rxy, Ryy, Ryp;
- celt_word32_t score;
- celt_word32_t g;
- celt_word16_t s;
-
- /* Select sign based on X[j] alone */
- sign = signx[j];
- s = SHL16(sign*pulsesAtOnce, yshift);
-
- /* Updating the sums of the new pulse(s) */
- Rxy = xy + MULT16_16(s,X[j]);
- Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
- Ryp = yp + MULT16_16(s, P[j]);
-
- if (pulsesLeft>1)
+ for (j=0;j<N;j++)
{
- score = MULT32_32_Q31(MULT16_16(ROUND16(Rxy,14),ABS16(ROUND16(Rxy,14))), celt_rcp(SHR32(Ryy,12)));
- } else
+ /* Select sign based on X[j] alone */
+ sign = signx[j];
+ s = SHL16(sign*pulsesAtOnce, yshift);
+ /* Temporary sums of the new pulse(s) */
+ Rxy = xy + MULT16_16(s,X[j]);
+ Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
+ Ryp = yp + MULT16_16(s, P[j]);
+ scores[j] = MULT32_32_Q31(MULT16_16(ROUND16(Rxy,14),ABS16(ROUND16(Rxy,14))), celt_rcp(SHR32(Ryy,12)));
+ }
+ } else {
+ for (j=0;j<N;j++)
{
+ /* Select sign based on X[j] alone */
+ sign = signx[j];
+ s = SHL16(sign*pulsesAtOnce, yshift);
+ /* Temporary sums of the new pulse(s) */
+ Rxy = xy + MULT16_16(s,X[j]);
+ Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
+ Ryp = yp + MULT16_16(s, P[j]);
+
/* Compute the gain such that ||p + g*y|| = 1 */
g = MULT16_32_Q15(
celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
@@ -206,54 +191,23 @@
/* Knowing that gain, what's the error: (x-g*y)^2
(result is negated and we discard x^2 because it's constant) */
/* score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
- score = 2*MULT16_32_Q14(ROUND16(Rxy,14),g)
+ scores[j] = 2*MULT16_32_Q14(ROUND16(Rxy,14),g)
- MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND16(Ryy,14),g)),g);
}
-
- if (score>nbest.score)
- {
- nbest.score = score;
- nbest.pos = j;
- nbest.sign = sign;
- nbest.xy = Rxy;
- nbest.yy = Ryy;
- nbest.yp = Ryp;
- }
}
+
+ j = find_max32(scores, N);
+ is = signx[j]*pulsesAtOnce;
+ s = SHL16(is, yshift);
- celt_assert2(nbest.score > -VERY_LARGE32, "Could not find any match in VQ codebook. Something got corrupted somewhere.");
+ /* Updating the sums of the new pulse(s) */
+ xy = xy + MULT16_16(s,X[j]);
+ yy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
+ yp = yp + MULT16_16(s, P[j]);
- /* Only now that we've made the final choice, update ny/iny and others */
- {
- int n;
- int is;
- celt_norm_t s;
- is = nbest.sign*pulsesAtOnce;
- s = SHL16(is, yshift);
- for (n=0;n<N;n++)
- ny[n] = y[n];
- ny[nbest.pos] += s;
-
- for (n=0;n<N;n++)
- iny[n] = iy[n];
- iny[nbest.pos] += is;
-
- xy = nbest.xy;
- yy = nbest.yy;
- yp = nbest.yp;
- }
- /* Swap ny/iny with y/iy */
- {
- celt_norm_t *tmp_ny;
- int *tmp_iny;
-
- tmp_ny = ny;
- ny = y;
- y = tmp_ny;
- tmp_iny = iny;
- iny = iy;
- iy = tmp_iny;
- }
+ /* Only now that we've made the final choice, update y/iy */
+ y[j] += s;
+ iy[j] += is;
pulsesLeft -= pulsesAtOnce;
}