shithub: puzzles

Download patch

ref: 6b5142a7a9b31922d9c7ef505b27c33d551f5016
parent: ad7042db989eb525defea9298b2b14d564498473
author: Simon Tatham <[email protected]>
date: Sun Jul 2 17:22:02 EDT 2023

Move mul_root3 out into misc.c and generalise it.

I'm going to want to reuse it for sqrt(5) as well as sqrt(3) soon.

--- a/grid.c
+++ b/grid.c
@@ -3669,131 +3669,6 @@
     tree234 *points;
 };
 
-/*
- * Calculate the nearest integer to n*sqrt(3), via a bitwise algorithm
- * that avoids floating point.
- *
- * (It would probably be OK in practice to use floating point, but I
- * felt like overengineering it for fun. With FP, there's at least a
- * theoretical risk of rounding the wrong way, due to the three
- * successive roundings involved - rounding sqrt(3), rounding its
- * product with n, and then rounding to the nearest integer. This
- * approach avoids that: it's exact.)
- */
-static int mul_root3(int n_signed)
-{
-    unsigned x, r, m;
-    int sign = n_signed < 0 ? -1 : +1;
-    unsigned n = n_signed * sign;
-    unsigned bitpos;
-
-    /*
-     * Method:
-     *
-     * We transform m gradually from zero into n, by multiplying it by
-     * 2 in each step and optionally adding 1, so that it's always
-     * floor(n/2^something).
-     *
-     * At the start of each step, x is the largest integer less than
-     * or equal to m*sqrt(3). We transform m to 2m+bit, and therefore
-     * we must transform x to 2x+something to match. The 'something'
-     * we add to 2x is at most 3. (Worst case is if m sqrt(3) was
-     * equal to x + 1-eps for some tiny eps, and then the incoming bit
-     * of m is 1, so that (2m+1)sqrt(3) = 2x+2+2eps+sqrt(3), i.e.
-     * about 2x + 3.732...)
-     *
-     * To compute this, we also track the residual value r such that
-     * x^2+r = 3m^2.
-     *
-     * The algorithm below is very similar to the usual approach for
-     * taking the square root of an integer in binary. The wrinkle is
-     * that we have an integer multiplier, i.e. we're computing
-     * P*sqrt(Q) (with P=n and Q=3 in this case) rather than just
-     * sqrt(Q). Of course in principle we could just take sqrt(P^2Q),
-     * but we'd need an integer twice the width to hold P^2. Pulling
-     * out P and treating it specially makes overflow less likely.
-     */
-
-    x = r = m = 0;
-
-    for (bitpos = UINT_MAX & ~(UINT_MAX >> 1); bitpos; bitpos >>= 1) {
-        unsigned a, b = (n & bitpos) ? 1 : 0;
-
-        /*
-         * Check invariants. We expect that x^2 + r = 3m^2 (i.e. our
-         * residual term is correct), and also that r < 2x+1 (because
-         * if not, then we could replace x with x+1 and still get a
-         * value that made r non-negative, i.e. x would not be the
-         * _largest_ integer less than m sqrt(3)).
-         */
-        assert(x*x + r == 3*m*m);
-        assert(r < 2*x+1);
-
-        /*
-         * We're going to replace m with 2m+b, and x with 2x+a for
-         * some a we haven't decided on yet.
-         *
-         * The new value of the residual will therefore be
-         *
-         *   3 (2m+b)^2 - (2x+a)^2
-         * = (12m^2 + 12mb + 3b^2) - (4x^2 + 4xa + a^2)
-         * = 4 (3m^2 - x^2) + 12mb + 3b^2 - 4xa - a^2
-         * = 4r + 12mb + 3b^2 - 4xa - a^2          (because r = 3m^2 - x^2)
-         * = 4r + (12m + 3)b - 4xa - a^2           (b is 0 or 1, so b = b^2)
-         */
-        for (a = 0; a < 4; a++) {
-            /* If we made this routine handle square roots of numbers
-             * other than 3 then it would be sensible to make this a
-             * binary search. Here, it hardly seems important. */
-            unsigned pos = 4*r + b*(12*m + 3);
-            unsigned neg = 4*a*x + a*a;
-            if (pos < neg)
-                break;                 /* this value of a is too big */
-        }
-
-        /* The above loop will have terminated with a one too big,
-         * whether that's because we hit the break statement or fell
-         * off the end with a=4. So now decrementing a will give us
-         * the right value to add. */
-        a--;
-
-        r = 4*r + b*(12*m + 3) - (4*a*x + a*a);
-        m = 2*m+b;
-        x = 2*x+a;
-    }
-
-    /*
-     * Finally, round to the nearest integer. At present, x is the
-     * largest integer that is _at most_ m sqrt(3). But we want the
-     * _nearest_ integer, whether that's rounded up or down. So check
-     * whether (x + 1/2) is still less than m sqrt(3), i.e. whether
-     * (x + 1/2)^2 < 3m^2; if it is, then we increment x.
-     *
-     * We have 3m^2 - (x + 1/2)^2 = 3m^2 - x^2 - x - 1/4
-     *                            = r - x - 1/4
-     *
-     * and since r and x are integers, this is greater than 0 if and
-     * only if r > x.
-     *
-     * (There's no need to worry about tie-breaking exact halfway
-     * rounding cases. sqrt(3) is irrational, so none such exist.)
-     */
-    if (r > x)
-        x++;
-
-    /*
-     * Put the sign back on, and convert back from unsigned to int.
-     */
-    if (sign == +1) {
-        return x;
-    } else {
-        /* Be a little careful to avoid compilers deciding I've just
-         * perpetrated signed-integer overflow. This should optimise
-         * down to no actual code. */
-        return INT_MIN + (int)(-x - (unsigned)INT_MIN);
-    }
-}
-
 static void grid_spectres_callback(void *vctx, const int *coords)
 {
     struct spectrecontext *ctx = (struct spectrecontext *)vctx;
@@ -3804,9 +3679,9 @@
         grid_dot *d = grid_get_dot(
             ctx->g, ctx->points,
             (coords[4*i+0] * SPECTRE_UNIT +
-             mul_root3(coords[4*i+1] * SPECTRE_UNIT)),
+             n_times_root_k(coords[4*i+1] * SPECTRE_UNIT, 3)),
             (coords[4*i+2] * SPECTRE_UNIT +
-             mul_root3(coords[4*i+3] * SPECTRE_UNIT)));
+             n_times_root_k(coords[4*i+3] * SPECTRE_UNIT, 3)));
         grid_face_set_dot(ctx->g, d, i);
     }
 }
--- a/misc.c
+++ b/misc.c
@@ -536,4 +536,128 @@
     return path;
 }
 
+/*
+ * Calculate the nearest integer to n*sqrt(k), via a bitwise algorithm
+ * that avoids floating point. k must be non-negative (perfect squares
+ * work too). |n|*sqrt(k) must be far below UINT_MAX, because the
+ * intermediate values grow to a small multiple (depending on k) of it.
+ *
+ * (Floating point would probably be OK in practice, but it involves
+ * three successive roundings - of sqrt(k), of its product with n, and
+ * to the nearest integer - so there's at least a theoretical risk of
+ * rounding the wrong way. This approach avoids that: it's exact.)
+ */
+int n_times_root_k(int n_signed, int k)
+{
+    unsigned x, r, m, bitpos;
+    int sign = n_signed < 0 ? -1 : +1;
+    /* Negate in unsigned arithmetic, so that INT_MIN can't overflow. */
+    unsigned n = sign < 0 ? -(unsigned)n_signed : (unsigned)n_signed;
+
+    /*
+     * Method:
+     *
+     * We transform m gradually from zero into n, by multiplying it by
+     * 2 in each step and optionally adding 1, so that it's always
+     * floor(n/2^something).
+     *
+     * At the start of each step, x is the largest integer less than
+     * or equal to m*sqrt(k). We transform m to 2m+bit, and therefore
+     * we must transform x to 2x+something to match. The 'something'
+     * we add to 2x is at most floor(sqrt(k))+2. (Worst case is if m
+     * sqrt(k) was equal to x + 1-eps for some tiny eps, and then the
+     * incoming bit of m is 1, so that (2m+1)sqrt(k) =
+     * 2x+2+sqrt(k)-2eps.)
+     *
+     * To compute this, we also track the residual value r such that
+     * x^2+r = km^2.
+     *
+     * The algorithm below is very similar to the usual approach for
+     * taking the square root of an integer in binary. The wrinkle is
+     * that we have an integer multiplier, i.e. we're computing
+     * n*sqrt(k) rather than just sqrt(k). Of course in principle we
+     * could just take sqrt(n^2k), but we'd need an integer twice the
+     * width to hold n^2. Pulling out n and treating it specially
+     * makes overflow less likely.
+     */
+
+    assert(k >= 0);          /* negative k has no real square root */
+    x = r = m = 0;
+    for (bitpos = UINT_MAX & ~(UINT_MAX >> 1); bitpos; bitpos >>= 1) {
+        unsigned a, b = (n & bitpos) ? 1 : 0;
+
+        /*
+         * Check invariants. We expect that x^2 + r = km^2 (i.e. our
+         * residual term is correct), and also that r < 2x+1 (because
+         * if not, then we could replace x with x+1 and still get a
+         * value that made r non-negative, i.e. x would not be the
+         * _largest_ integer at most m sqrt(k)).
+         */
+        assert(x*x + r == k*m*m);
+        assert(r < 2*x+1);
+
+        /*
+         * We're going to replace m with 2m+b, and x with 2x+a for
+         * some a we haven't decided on yet.
+         *
+         * The new value of the residual will therefore be
+         *
+         *   k (2m+b)^2 - (2x+a)^2
+         * = (4km^2 + 4kmb + kb^2) - (4x^2 + 4xa + a^2)
+         * = 4 (km^2 - x^2) + 4kmb + kb^2 - 4xa - a^2
+         * = 4r + 4kmb + kb^2 - 4xa - a^2          (because r = km^2 - x^2)
+         * = 4r + (4m + 1)kb - 4xa - a^2           (b is 0 or 1, so b = b^2)
+         */
+        for (a = 0;; a++) {
+            /* If we made this routine handle square roots of numbers
+             * significantly bigger than 3 or 5 then it would be
+             * sensible to make this a binary search. Here, it hardly
+             * seems important. */
+            unsigned pos = 4*r + k*b*(4*m + 1);
+            unsigned neg = 4*a*x + a*a;
+            if (pos < neg)
+                break;                 /* this value of a is too big */
+        }
+
+        /* The loop stops at the first a that is too big (never a = 0,
+         * since then neg = 0), so decrementing gives the value to add. */
+        a--;
+
+        r = 4*r + b*k*(4*m + 1) - (4*a*x + a*a);
+        m = 2*m+b;
+        x = 2*x+a;
+    }
+
+    /*
+     * Finally, round to the nearest integer. At present, x is the
+     * largest integer that is _at most_ m sqrt(k). But we want the
+     * _nearest_ integer, whether that's rounded up or down. So check
+     * whether (x + 1/2) is still less than m sqrt(k), i.e. whether
+     * (x + 1/2)^2 < km^2; if it is, then we increment x.
+     *
+     * We have km^2 - (x + 1/2)^2 = km^2 - x^2 - x - 1/4
+     *                            = r - x - 1/4
+     *
+     * and since r and x are integers, this is greater than 0 if and
+     * only if r > x.
+     *
+     * (There's no need to worry about tie-breaking exact halfway
+     * rounding cases: r - x - 1/4 can never be exactly zero.)
+     */
+    if (r > x)
+        x++;
+
+    /*
+     * Put the sign back on, and convert back from unsigned to int.
+     */
+    if (sign == +1 || x == 0) {
+        return x;
+    } else {
+        /* Compute -x without signed overflow (needs 1 <= x <=
+         * -(unsigned)INT_MIN; zero is handled above). This should
+         * optimise down to no actual code. */
+        return INT_MIN + (int)(-x - (unsigned)INT_MIN);
+    }
+}
+
 /* vim: set shiftwidth=4 tabstop=8: */
--- a/puzzles.h
+++ b/puzzles.h
@@ -391,6 +391,7 @@
 char *fgetline(FILE *fp);
 char *make_prefs_path(const char *dir, const char *sep,
                       const game *game, const char *suffix);
+int n_times_root_k(int n, int k);
 
 /* allocates output each time. len is always in bytes of binary data.
  * May assert (or just go wrong) if lengths are unchecked. */