ref: 2d5321bb2fa2e34fce928bfef653776aba2c39d3
dir: /libstd/bigint.myr/
use "alloc.use"
use "cmp.use"
use "die.use"
use "extremum.use"
use "fmt.use"
use "hasprefix.use"
use "chartype.use"
use "option.use"
use "slcp.use"
use "sldup.use"
use "slfill.use"
use "slpush.use"
use "types.use"
use "utf.use"

pkg std =
	type bigint = struct
		dig	: uint32[:]	/* little endian, no leading zeros. */
		sign	: int		/* -1 for -ve, 0 for zero, 1 for +ve. */
	;;

	/* administrivia */
	const mkbigint	: (v : int32 -> bigint#)
	const bigfree	: (a : bigint# -> void)
	const bigdup	: (a : bigint# -> bigint#)
	const bigassign	: (d : bigint#, s : bigint# -> bigint#)
	const bigparse	: (s : byte[:] -> option(bigint#))
	const bigfmt	: (b : byte[:], a : bigint# -> size)

	/* some useful predicates */
	const bigiszero	: (a : bigint# -> bool)
	const bigcmp	: (a : bigint#, b : bigint# -> order)

	/* bigint*bigint -> bigint ops */
	const bigadd	: (a : bigint#, b : bigint# -> bigint#)
	const bigsub	: (a : bigint#, b : bigint# -> bigint#)
	const bigmul	: (a : bigint#, b : bigint# -> bigint#)
	const bigdiv	: (a : bigint#, b : bigint# -> bigint#)
	const bigmod	: (a : bigint#, b : bigint# -> bigint#)
	const bigdivmod	: (a : bigint#, b : bigint# -> (bigint#, bigint#))
	const bigshl	: (a : bigint#, b : bigint# -> bigint#)
	const bigshr	: (a : bigint#, b : bigint# -> bigint#)

	/* bigint*int -> bigint ops */
	const bigaddi	: (a : bigint#, b : int64 -> bigint#)
	const bigsubi	: (a : bigint#, b : int64 -> bigint#)
	const bigmuli	: (a : bigint#, b : int64 -> bigint#)
	const bigdivi	: (a : bigint#, b : int64 -> bigint#)
	const bigshli	: (a : bigint#, b : uint64 -> bigint#)
	const bigshri	: (a : bigint#, b : uint64 -> bigint#)
;;

/* the radix of a single digit: 2^32 */
const Base = 0x100000000ul

/* allocates a new bigint holding the value 'v'. Release it with bigfree. */
const mkbigint = {v
	var a

	a = zalloc()
	a.dig = slalloc(1)
	if v < 0
		a.sign = -1
		v = -v
	elif v > 0
		a.sign = 1
	;;
	a.dig[0] = (v castto(uint32))
	/* trim drops the digit again (and zeroes the sign) when v == 0 */
	-> trim(a)
}

/* frees the digits of 'a' and 'a' itself */
const bigfree = {a
	slfree(a.dig)
	free(a)
}

/* returns a newly allocated copy of 'a' */
const bigdup = {a
	-> bigassign(zalloc(), a)
}

/* d = s. Replaces the digits and sign of 'd' with a copy of those of 's'. */
const bigassign = {d, s
	var dig

	/* copy before freeing, so that assigning a value to itself is safe */
	dig = sldup(s.dig)
	slfree(d.dig)
	d.dig = dig
	d.sign = s.sign
	-> d
}

/*
 writes the decimal representation of 'val' into 'buf', and returns
 the number of bytes written.  for now, just dump out something for
 debugging...
*/
const bigfmt = {buf, val
	const digitchars = ['0','1','2','3','4','5','6','7','8','9','a','b','c','d','e','f']
	var v
	var n, i
	var tmp
	var rem
	var ten
	var sgn

	n = 0
	sgn = val.sign
	/* the divisor is allocated once, instead of once per digit */
	ten = mkbigint(10)
	val = bigdup(val)
	/* generate the digits in reverse order */
	while !bigiszero(val)
		(v, rem) = bigdivmod(val, ten)
		if rem.dig.len > 0
			n += bfmt(buf[n:], "%c", digitchars[rem.dig[0]])
		else
			n += bfmt(buf[n:], "0")
		;;
		bigfree(val)
		bigfree(rem)
		val = v
	;;
	bigfree(val)
	bigfree(ten)

	/*
	 the sign goes after the reversed digits, so that the reversal
	 below moves it to the front of the output.
	*/
	if sgn == 0
		n += encode(buf[n:], '0')
	elif sgn == -1
		n += encode(buf[n:], '-')
	;;

	/* we only generated ascii digits, so this works. ugh. */
	for i = 0; i < n/2; i++
		tmp = buf[i]
		buf[i] = buf[n - i - 1]
		buf[n - i - 1] = tmp
	;;
	-> n
}

/*
 parses 'str' as a bigint.  A "0x", "0o" or "0b" prefix (in either
 case) selects base 16, 8 or 2; otherwise base 10 is used.  '_' is
 skipped as a digit separator.  Returns `None if a character is
 rejected by charval for the selected base.
*/
const bigparse = {str
	var c, val, base
	var v, b
	var a

	if hasprefix(str, "0x") || hasprefix(str, "0X")
		base = 16
	elif hasprefix(str, "0o") || hasprefix(str, "0O")
		base = 8
	elif hasprefix(str, "0b") || hasprefix(str, "0B")
		base = 2
	else
		base = 10
	;;
	/* skip the radix prefix, so that its letter is not parsed as a digit */
	if base != 10
		str = str[2:]
	;;

	a = mkbigint(0)
	b = mkbigint(base castto(int32))
	/*
	 efficiency hack: to save allocations,
	 just mutate v[0]. The value will always
	 fit in one digit.
	*/
	v = mkbigint(1)
	while str.len != 0
		(c, str) = striter(str)
		if c == '_'
			continue
		;;
		val = charval(c, base)
		if val < 0
			bigfree(a)
			bigfree(b)
			bigfree(v)
			-> `None
		;;
		v.dig[0] = val
		if val == 0
			v.sign = 0
		else
			v.sign = 1
		;;
		bigmul(a, b)
		bigadd(a, v)
	;;
	/* the temporaries are only needed while accumulating */
	bigfree(b)
	bigfree(v)
	-> `Some a
}

/* true if 'v' is zero: zero is represented with no digits */
const bigiszero = {v
	-> v.dig.len == 0
}

/* orders 'a' relative to 'b', taking the signs into account */
const bigcmp = {a, b
	if a.sign < b.sign
		-> `Before
	elif a.sign > b.sign
		-> `After
	;;
	/* same sign: order by magnitude, flipped for negative values */
	match ucmp(a, b)
	| `After:
		-> signedorder(a.sign)
	| `Before:
		-> signedorder(-a.sign)
	| `Equal:
		-> `Equal
	;;
}

/* orders the magnitude of 'a' relative to that of 'b', ignoring signs */
const ucmp = {a, b
	var i

	/* no leading zeros, so the one with more digits has greater magnitude */
	if a.dig.len > b.dig.len
		-> `After
	elif a.dig.len < b.dig.len
		-> `Before
	;;
	/* otherwise, the one with the first larger digit is bigger */
	for i = a.dig.len; i > 0; i--
		if a.dig[i - 1] > b.dig[i - 1]
			-> `After
		elif a.dig[i - 1] < b.dig[i - 1]
			-> `Before
		;;
	;;
	-> `Equal
}

/* `Before for a negative sign, `After otherwise */
const signedorder = {sign
	if sign < 0
		-> `Before
	else
		-> `After
	;;
}

/* a += b */
const bigadd = {a, b
	if a.sign == b.sign || a.sign == 0
		a.sign = b.sign
		-> uadd(a, b)
	elif b.sign == 0
		-> a
	else
		/* opposite signs: the operand with the larger magnitude wins */
		-> udiff(a, b, b.sign)
	;;
}

/*
 adds two unsigned values together. The sum is stored in 'a'; the
 sign of 'a' is left alone unless the sum is zero.
*/
const uadd = {a, b
	var v, i
	var carry
	var n
	var bi

	carry = 0
	n = max(a.dig.len, b.dig.len)
	/* guaranteed to carry no more than one value */
	a.dig = slzgrow(a.dig, n + 1)
	for i = 0; i < n; i++
		/* 'b' may be shorter than 'a'; its missing digits are zero */
		if i < b.dig.len
			bi = (b.dig[i] castto(uint64))
		else
			bi = 0
		;;
		v = (a.dig[i] castto(uint64)) + bi + carry
		if v >= Base
			carry = 1
		else
			carry = 0
		;;
		a.dig[i] = v castto(uint32)
	;;
	a.dig[i] += carry castto(uint32)
	-> trim(a)
}

/* a -= b */
const bigsub = {a, b
	if a.sign == 0
		/* 0 - x = -x */
		bigassign(a, b)
		a.sign = -b.sign
		-> a
	elif b.sign == 0
		/* x - 0 = x */
		-> a
	elif a.sign != b.sign
		/* opposite signs: the magnitudes add, and 'a' keeps its sign */
		-> uadd(a, b)
	else
		/* same sign: the result flips sign if 'b' has the larger magnitude */
		-> udiff(a, b, -a.sign)
	;;
}

/*
 stores the difference of the magnitudes of 'a' and 'b' in 'a'.
 The result keeps the sign of 'a' if 'a' has the larger magnitude,
 takes 'flipsign' if 'b' has the larger magnitude, and is zero if
 the magnitudes are equal. 'b' is never modified.
*/
const udiff = {a, b, flipsign
	var t

	match ucmp(a, b)
	| `After:
		-> usub(a, b)
	| `Before:
		/* subtract into a copy of 'b', then move the digits into 'a' */
		t = bigdup(b)
		usub(t, a)
		slfree(a.dig)
		a.dig = t.dig
		a.sign = flipsign
		free(t)
		-> a
	| `Equal:
		slfree(a.dig)
		a.dig = [][:]
		a.sign = 0
		-> a
	;;
}

/*
 subtracts two unsigned values, where 'a' is strictly greater than 'b'.
 The difference is stored in 'a'.
*/
const usub = {a, b
	var carry
	var v, i
	var bi

	carry = 0
	for i = 0; i < a.dig.len; i++
		/* 'b' may be shorter than 'a'; its missing digits are zero */
		if i < b.dig.len
			bi = (b.dig[i] castto(int64))
		else
			bi = 0
		;;
		v = (a.dig[i] castto(int64)) - bi - carry
		if v < 0
			carry = 1
		else
			carry = 0
		;;
		/* the cast keeps the low 32 bits, i.e. the digit modulo Base */
		a.dig[i] = v castto(uint32)
	;;
	-> trim(a)
}

/* a *= b */
const bigmul = {a, b
	var i, j
	var ai, bj, wij
	var carry, t
	var w

	if a.sign == 0 || b.sign == 0
		a.sign = 0
		slfree(a.dig)
		a.dig = [][:]
		-> a
	elif a.sign != b.sign
		a.sign = -1
	else
		a.sign = 1
	;;
	/* schoolbook multiplication into a zeroed scratch buffer */
	w = slzalloc(a.dig.len + b.dig.len)
	for j = 0; j < b.dig.len; j++
		/*
		 a zero digit of 'b' contributes nothing. 'w' starts out
		 zeroed, so there is nothing to store for it.
		*/
		if b.dig[j] == 0
			continue
		;;
		carry = 0
		for i = 0; i < a.dig.len; i++
			ai = a.dig[i] castto(uint64)
			bj = b.dig[j] castto(uint64)
			wij = w[i+j] castto(uint64)
			t = ai * bj + wij + carry
			w[i + j] = (t castto(uint32))
			carry = t >> 32
		;;
		/* i == a.dig.len here: store the final carry of this row */
		w[i+j] = carry castto(uint32)
	;;
	slfree(a.dig)
	a.dig = w
	-> trim(a)
}

/* a /= b. The quotient comes from bigdivmod. */
const bigdiv = {a : bigint#, b : bigint# -> bigint#
	var q, r

	(q, r) = bigdivmod(a, b)
	bigfree(r)
	/* take over both the digits and the sign of the quotient */
	slfree(a.dig)
	a.dig = q.dig
	a.sign = q.sign
	free(q)
	-> a
}

/* a %= b. The remainder comes from bigdivmod. */
const bigmod = {a : bigint#, b : bigint# -> bigint#
	var q, r

	(q, r) = bigdivmod(a, b)
	bigfree(q)
	/* take over both the digits and the sign of the remainder */
	slfree(a.dig)
	a.dig = r.dig
	a.sign = r.sign
	free(r)
	-> a
}

/*
 returns the newly allocated pair (a / b, a % b); neither 'a' nor 'b'
 is modified.  The quotient is negative when the signs of 'a' and 'b'
 differ, and a nonzero remainder has the sign of 'a'.  Dies if 'b' is
 zero.
*/
const bigdivmod = {a : bigint#, b : bigint# -> (bigint#, bigint#)
	/*
	 Implements bigint division using Algorithm D from
	 Knuth: Seminumerical algorithms, Section 4.3.1.
	*/
	var qhat, rhat, carry, shift
	var x, y, z, w, p, t /* temporaries */
	var b0, aj
	var u, v
	var m : int64, n : int64
	var i, j : int64
	var q
	var r

	if bigiszero(b)
		die("divide by zero\n")
	;;
	/* if b > a, we truncate to 0, with remainder 'a' */
	if a.dig.len < b.dig.len
		-> (mkbigint(0), bigdup(a))
	;;

	q = zalloc()
	q.dig = slzalloc(max(a.dig.len, b.dig.len) + 1)
	if a.sign != b.sign
		q.sign = -1
	else
		q.sign = 1
	;;

	/*
	 handle single digit divisor separately: the knuth algorithm
	 needs at least 2 digits.
	*/
	if b.dig.len == 1
		carry = 0
		b0 = (b.dig[0] castto(uint64))
		for j = a.dig.len; j > 0; j--
			aj = (a.dig[j - 1] castto(uint64))
			q.dig[j - 1] = (((carry << 32) + aj)/b0) castto(uint32)
			carry = (carry << 32) + aj - (q.dig[j-1] castto(uint64))*b0
		;;
		/*
		 the remainder can use all 32 bits of a digit, so it is
		 stored directly instead of going through an int32.
		*/
		r = zalloc()
		r.dig = slalloc(1)
		r.dig[0] = (carry castto(uint32))
		r.sign = a.sign
		-> (trim(q), trim(r))
	;;

	u = bigdup(a)
	v = bigdup(b)
	m = u.dig.len
	n = v.dig.len

	/* normalize, so that the top bit of the divisor's top digit is set */
	shift = nlz(v.dig[n - 1])
	bigshli(u, shift)
	bigshli(v, shift)
	/*
	 bigshli trims leading zeros, but the loop below reads and
	 writes u.dig[m], so that digit has to exist.
	*/
	u.dig = slzgrow(u.dig, (m + 1) castto(size))

	for j = m - n; j >= 0; j--
		/* load a few temps */
		x = u.dig[j + n] castto(uint64)
		y = u.dig[j + n - 1] castto(uint64)
		z = v.dig[n - 1] castto(uint64)
		w = v.dig[n - 2] castto(uint64)
		t = u.dig[j + n - 2] castto(uint64)

		/* estimate qhat */
		qhat = (x*Base + y)/z
		rhat = (x*Base + y) - (qhat * z)
:divagain
		if qhat >= Base || (qhat * w) > (rhat*Base + t)
			qhat--
			rhat += z
			if rhat < Base
				goto divagain
			;;
		;;

		/* multiply and subtract */
		carry = 0
		for i = 0; i < n; i++
			p = qhat * (v.dig[i] castto(uint64))
			t = (u.dig[i+j] castto(uint64)) - carry - (p % Base)
			u.dig[i+j] = t castto(uint32)
			/*
			 'p' can exceed 2^63, so its high half is taken
			 before the value is reinterpreted as signed.
			*/
			carry = (((p >> 32) castto(int64)) - ((t castto(int64)) >> 32)) castto(uint64)
		;;
		x = u.dig[j + n] castto(uint64)
		t = x - carry
		u.dig[j + n] = t castto(uint32)
		q.dig[j] = qhat castto(uint32)
		/* adjust: qhat was one too large, so add the divisor back */
		if x < carry
			q.dig[j]--
			carry = 0
			for i = 0; i < n; i++
				t = (u.dig[i+j] castto(uint64)) + (v.dig[i] castto(uint64)) + carry
				u.dig[i+j] = t castto(uint32)
				carry = t >> 32
			;;
			u.dig[j+n] = u.dig[j+n] + (carry castto(uint32))
		;;
	;;

	/* undo the biasing for remainder */
	bigshri(u, shift)
	/* the normalized divisor was only a temporary */
	bigfree(v)
	-> (trim(q), trim(u))
}

/* returns the number of leading zeros */
const nlz = {a : uint32
	var n

	if a == 0
		-> 32
	;;
	/* binary search: move the highest set bit up to bit 31 */
	n = 0
	if a <= 0x0000ffff
		n += 16
		a <<= 16
	;;
	if a <= 0x00ffffff
		n += 8
		a <<= 8
	;;
	if a <= 0x0fffffff
		n += 4
		a <<= 4
	;;
	if a <= 0x3fffffff
		n += 2
		a <<= 2
	;;
	if a <= 0x7fffffff
		n += 1
		a <<= 1
	;;
	-> n
}

/* a <<= b. Dies if 'b' has more than one digit. */
const bigshl = {a, b
	match b.dig.len
	| 0:	-> a
	| 1:	-> bigshli(a, b.dig[0] castto(uint64))
	| n:	die("shift by way too much\n")
	;;
}

/* a >>= b, unsigned. Dies if 'b' has more than one digit. */
const bigshr = {a, b
	match b.dig.len
	| 0:	-> a
	| 1:	-> bigshri(a, b.dig[0] castto(uint64))
	| n:	die("shift by way too much\n")
	;;
}

/*
 a << s, with integer arg.
 logical left shift. any other type would be illogical.
*/
const bigshli = {a, s
	var off, shift
	var t, carry
	var i

	/* whole digits to move, and the bits left over */
	off = s/32
	shift = s % 32

	a.dig = slzgrow(a.dig, 1 + a.dig.len + off castto(size))
	/* blit over the base values */
	for i = a.dig.len; i > off; i--
		a.dig[i - 1] = a.dig[i - 1 - off]
	;;
	for i = 0; i < off; i++
		a.dig[i] = 0
	;;
	/* and shift over by the remainder */
	carry = 0
	for i = 0; i < a.dig.len; i++
		t = (a.dig[i] castto(uint64)) << shift
		a.dig[i] = (t | carry) castto(uint32)
		carry = t >> 32
	;;
	-> trim(a)
}

/*
 logical shift right, zero fills. sign remains untouched, unless
 the result is zero.
*/
const bigshri = {a, s
	var off, shift
	var t, carry
	var i

	/* whole digits to drop, and the bits left over */
	off = s/32
	shift = s % 32

	/* every digit is shifted out: clear them, and let trim produce zero */
	if off >= a.dig.len
		for i = 0; i < a.dig.len; i++
			a.dig[i] = 0
		;;
		-> trim(a)
	;;

	/* blit over the base values */
	for i = 0; i < a.dig.len - off; i++
		a.dig[i] = a.dig[i + off]
	;;
	/* zero the vacated high digits */
	for i = a.dig.len - off; i < a.dig.len; i++
		a.dig[i] = 0
	;;
	/* and shift over by the remainder */
	carry = 0
	for i = a.dig.len; i > 0; i--
		t = (a.dig[i - 1] castto(uint64))
		a.dig[i - 1] = (carry | (t >> shift)) castto(uint32)
		carry = t << (32 - shift)
	;;
	-> trim(a)
}

/* trims leading zeros, and keeps the sign consistent with the digits */
const trim = {a
	var i

	for i = a.dig.len; i > 0; i--
		if a.dig[i - 1] != 0
			break
		;;
	;;
	a.dig = slgrow(a.dig, i)
	if i == 0
		a.sign = 0
	elif a.sign == 0
		a.sign = 1
	;;
	-> a
}