ref: 6df4ee6319d8c943d766e99f231908831421e403
parent: 8acfc5e2a00efe3866d7c5aa9201d4bd0fe9216e
parent: c98dd1239d76a56d36ce6511239ccfce2b5c6e43
author: Ori Bernstein <[email protected]>
date: Thu Jan 23 05:42:47 EST 2014
Merge branch 'master' of git+ssh://git.eigenstate.org/git/ori/mc
--- a/libstd/bigint.myr
+++ b/libstd/bigint.myr
@@ -3,9 +3,12 @@
use "die.use"
use "extremum.use"
use "fmt.use"
+use "hasprefix.use"
+use "chartype.use"
use "option.use"
use "slcp.use"
use "sldup.use"
+use "slfill.use"
use "slpush.use"
use "types.use"
use "utf.use"
@@ -20,7 +23,8 @@
const mkbigint : (v : int32 -> bigint#)
const bigfree : (a : bigint# -> void)
const bigdup : (a : bigint# -> bigint#)
- const bigparse : (a : bigint# -> option(byte[:]))
+ const bigassign : (d : bigint#, s : bigint# -> bigint#)
+ const bigparse : (s : byte[:] -> option(bigint#))
const bigfmt : (b : byte[:], a : bigint# -> size)
/* some useful predicates */
@@ -36,7 +40,6 @@
const bigdivmod : (a : bigint#, b : bigint# -> [bigint#, bigint#])
const bigshl : (a : bigint#, b : bigint# -> bigint#)
const bigshr : (a : bigint#, b : bigint# -> bigint#)
- const bigshra : (a : bigint#, b : bigint# -> bigint#)
/* bigint*int -> bigint ops */
const bigaddi : (a : bigint#, b : int64 -> bigint#)
@@ -45,7 +48,6 @@
const bigdivi : (a : bigint#, b : int64 -> bigint#)
const bigshli : (a : bigint#, b : uint64 -> bigint#)
const bigshri : (a : bigint#, b : uint64 -> bigint#)
- const bigshrai : (a : bigint#, b : uint64 -> bigint#)
;;
const mkbigint = {v
@@ -69,14 +71,17 @@
}
const bigdup = {a
- var v
+ -> bigassign(zalloc(), a)
+}
- v = zalloc()
- v.dig = sldup(a.dig)
- v.sign = a.sign
- -> v
+const bigassign = {d, s
+ var old = d.dig /* copy s.dig before freeing: d and s may be the same bigint */
+ d.dig = sldup(s.dig)
+ d.sign = s.sign; slfree(old)
+ -> d
}
+
/* for now, just dump out something for debugging... */
const bigfmt = {buf, val
const digitchars = ['0','1','2','3','4','5','6','7','8','9','a','b','c','d','e','f']
@@ -88,7 +93,6 @@
n = 0
if val.sign == 0
n += encode(buf, '0')
- -> n
elif val.sign == -1
n += encode(buf, '-')
;;
@@ -116,39 +120,51 @@
-> n
}
-/*
const bigparse = {str
- var c, i
- var val, base
+ var c, val, base
+ var v, b
var a
if hasprefix(str, "0x") || hasprefix(str, "0X")
base = 16
- if hasprefix(str, "0o") || hasprefix(str, "0O")
+ elif hasprefix(str, "0o") || hasprefix(str, "0O")
base = 8
- if hasprefix(str, "0b") || hasprefix(str, "0B")
+ elif hasprefix(str, "0b") || hasprefix(str, "0B")
base = 2
else
base = 10
;;
- a = mkbigint()
+ a = mkbigint(0)
+ b = mkbigint(base castto(int32))
+ /* strip the radix prefix, or striter would hand 'x'/'o'/'b' to charval */
+ if base != 10
+ str = str[2:]
+ ;;
+ /* v: one-digit scratch value reused per digit. FIXME: b and v are never freed */
+ v = mkbigint(1)
while str.len != 0
(c, str) = striter(str)
- val = charval(c)
+ val = charval(c, base)
if val < 0
bigfree(a)
-> `None
;;
- bigmuli(a, base)
- bigaddi(a, val)
+ v.dig[0] = val castto(uint32)
+ if val == 0
+ v.sign = 0
+ else
+ v.sign = 1
+ ;;
+ bigmul(a, b)
+ bigadd(a, v)
+
;;
-> `Some a
}
-*/
const bigiszero = {v
- -> v.sign == 0
+ -> v.sign == 0 || v.dig.len == 0 /* a zero sign or no digits both mean zero */
}
const bigcmp = {a, b
@@ -161,14 +177,14 @@
else
/* the one with more digits has greater magnitude */
if a.dig.len > b.dig.len
- -> signedmagorder(a.sign)
+ -> signedorder(a.sign)
;;
/* otherwise, the one with the first larger digit is bigger */
for i = a.dig.len; i > 0; i--
if a.dig[i - 1] > b.dig[i - 1]
- -> signedmagorder(a.sign)
+ -> signedorder(a.sign)
elif b.dig[i - 1] > a.dig[i - 1]
- -> signedmagorder(a.sign)
+ -> signedorder(a.sign)
;;
;;
;;
@@ -175,9 +191,9 @@
-> `Equal
}
-const signedmagorder = {sign
+const signedorder = {sign
if sign < 0
- -> `Before
+ -> `Before
else
-> `After
;;
@@ -185,8 +201,11 @@
/* a += b */
const bigadd = {a, b
- if a.sign == b.sign
+ if a.sign == b.sign || a.sign == 0
+ a.sign = b.sign
-> uadd(a, b)
+ elif b.sign == 0
+ -> a
else
match bigcmp(a, b)
| `Before: /* a is negative */
@@ -207,9 +226,9 @@
var n
carry = 0
- n = min(a.dig.len, b.dig.len)
+ n = max(a.dig.len, b.dig.len)
/* guaranteed to carry no more than one value */
- a.dig = slpush(a.dig, 0)
+ a.dig = slzgrow(a.dig, n + 1)
for i = 0; i < n; i++
v = (a.dig[i] castto(uint64)) + (b.dig[i] castto(uint64)) + carry;
if v > (0xffffffff castto(uint64))
@@ -225,7 +244,15 @@
/* a -= b */
const bigsub = {a, b
- if a.sign != b.sign
+ /* 0 - x = -x */
+ if a.sign == 0
+ bigassign(a, b)
+ a.sign = -b.sign
+ -> a
+ /* x - 0 = x */
+ elif b.sign == 0
+ -> a
+ elif a.sign != b.sign
-> uadd(a, b)
else
match bigcmp(a, b)
@@ -266,7 +293,12 @@
var carry, t
var w
- if a.sign != b.sign
+ if a.sign == 0 || b.sign == 0
+ a.sign = 0
+ slfree(a.dig)
+ a.dig = [][:]
+ -> a
+ elif a.sign != b.sign
a.sign = -1
else
a.sign = 1
@@ -316,19 +348,25 @@
/* a /= b */
const bigdivmod = {a : bigint#, b : bigint# -> [bigint#, bigint#]
- /*var u, v /* normalized a and b */*/
- var q/*, qhat /* quotient and estimated quotient */*/
/*
- var rhat /* remainder */
- var s, i, j, t
- var p
- */
- var j
- var carry, b0, aj
+ Implements bigint division using Algorithm D from
+ Knuth: Seminumerical algorithms, Section 4.3.1.
+ */
+ var qhat, rhat, carry, shift
+ var x, y, z, w, p, t /* temporaries */
+ var b0, aj
+ var u, v
+ var m : int64, n : int64
+ var i, j : int64
+ var q
if bigiszero(b)
die("divide by zero\n")
;;
+ /* if |b| > |a| (b has more digits), we truncate to 0, with remainder 'a' */
+ if a.dig.len < b.dig.len
+ -> (mkbigint(0), bigdup(a))
+ ;;
q = zalloc()
q.dig = slzalloc(max(a.dig.len, b.dig.len))
@@ -347,14 +385,99 @@
q.dig[j - 1] = (((carry << 32) + aj)/b0) castto(uint32)
carry = (carry << 32) + aj - (q.dig[j-1] castto(uint64))*b0
;;
- for v in q.dig
- ;;
-> (trim(q), trim(mkbigint(carry castto(int32))))
;;
- die("big bigint division not implemented\n")
- -> (trim(a), mkbigint(carry castto(int32)))
+
+ u = bigdup(a)
+ v = bigdup(b)
+ q = zalloc()
+ q.dig = slalloc(max(u.dig.len, v.dig.len))
+ m = u.dig.len
+ n = v.dig.len
+
+ shift = nlz(v.dig[n - 1])
+ bigshli(u, shift)
+ bigshli(v, shift)
+ for j = m - n; j >= 0; j--
+ /* load a few temps */
+ x = (u.dig[j + n] castto(uint64)) << 32
+ y = u.dig[j + n - 1] castto(uint64)
+ z = v.dig[n - 1] castto(uint64)
+ w = v.dig[n - 2] castto(uint64)
+
+ /* estimate qhat */
+ qhat = (x + y)/z
+ rhat = (x + y) - (qhat * z)
+:divagain
+ if qhat > 0xfffffffful || qhat * w > ((rhat << 32) + w)
+ qhat--
+ rhat += z
+ if rhat <= 0xfffffffful
+ goto divagain
+ ;;
+ ;;
+
+ /* multiply and subtract */
+ carry = 0
+ for i = 0; i < n; i++
+ p = qhat * (v.dig[i] castto(uint64))
+ t = (u.dig[i+j] castto(uint64)) - carry - (p & (0xffffffff castto(uint64)))
+ u.dig[i+j] = t castto(uint32)
+ carry = (p >> 32) - (t >> 32);
+ ;;
+ t = x - carry
+ u.dig[j + n] = t castto(uint32)
+ q.dig[j] = qhat castto(uint32)
+ /* adjust */
+ if x < carry
+ q.dig[j]--
+ carry = 0
+ for i = 0; i < n; i++
+ t = (u.dig[i+j] castto(uint64)) + (v.dig[i] castto(uint64)) + carry
+ u.dig[i+j] = t castto(uint32)
+ carry = t >> 32
+ ;;
+ u.dig[j+n] = u.dig[j+n] + (carry castto(uint32));
+ ;;
+
+ ;;
+ /* undo the biasing for remainder */
+ bigshri(u, shift)
+ -> (trim(q), trim(u))
}
+/* returns the number of leading zero bits in the 32-bit value a (32 if a is 0) */
+const nlz = {a : uint32
+ var n
+
+ if a == 0
+ -> 32
+ ;;
+ n = 0
+ if a <= 0x0000ffff
+ n += 16
+ a <<= 16
+ ;;
+ if a <= 0x00ffffff
+ n += 8
+ a <<= 8
+ ;;
+ if a <= 0x0fffffff
+ n += 4
+ a <<= 4
+ ;;
+ if a <= 0x3fffffff
+ n += 2
+ a <<= 2
+ ;;
+ if a <= 0x7fffffff
+ n += 1
+ a <<= 1
+ ;;
+ -> n
+}
+
+
/* a <<= b */
const bigshl = {a, b
match b.dig.len
@@ -373,30 +496,10 @@
;;
}
-/* a >>= b, sign extending */
-const bigshra = {a, b
- match b.dig.len
- | 0: -> a
- | 1: -> bigshrai(a, b.dig[0] castto(uint64))
- | n: die("shift by way too much\n")
- ;;
-}
-
-const trim = {a
- var i
-
- for i = a.dig.len; i > 0; i--
- if a.dig[i - 1] != 0
- break
- ;;
- ;;
- a.dig = slgrow(a.dig, i)
- if i == 0
- a.sign = 0
- ;;
- -> a
-}
-
+/*
+ a << s, with integer arg.
+ logical left shift. any other type would be illogical.
+ */
const bigshli = {a, s
var off, shift
var t, carry
@@ -423,19 +526,8 @@
-> trim(a)
}
+/* logical shift right, zero fills. sign remains untouched. */
const bigshri = {a, s
- -> bigshrfill(a, s, 0)
-}
-
-const bigshrai = {a, s
- if a.sign == -1
- -> bigshrfill(a, s, ~0)
- else
- -> bigshrfill(a, s, 0)
- ;;
-}
-
-const bigshrfill = {a, s, fill
var off, shift
var t, carry
var i
@@ -448,7 +540,7 @@
a.dig[i] = a.dig[i + off]
;;
for i = a.dig.len; i < a.dig.len + off; i++
- a.dig[i] = fill
+ a.dig[i] = 0
;;
/* and shift over by the remainder */
carry = 0
@@ -459,3 +551,22 @@
;;
-> trim(a)
}
+
+/* drops high-order zero digits; an empty result gets sign 0, otherwise sign 0 becomes 1 */
+const trim = {a
+ var n = a.dig.len
+ while n > 0
+ if a.dig[n - 1] != 0
+ break
+ ;;
+ n--
+ ;;
+ a.dig = slgrow(a.dig, n)
+ if n == 0
+ a.sign = 0
+ elif a.sign == 0
+ a.sign = 1
+ ;;
+ -> a
+}
+
--- a/parse/infer.c
+++ b/parse/infer.c
@@ -980,6 +980,22 @@
case Olit:
case Omemb:
infernode(st, n, NULL, NULL); break;
+ /* arithmetic expressions just need to be constant */
+ case Oneg:
+ case Oadd:
+ case Osub:
+ case Omul:
+ case Odiv:
+ case Obsl:
+ case Obsr:
+ case Oband:
+ case Obor:
+ case Obxor:
+ case Obnot:
+ infernode(st, n, NULL, NULL);
+ if (!n->expr.isconst)
+ fatal(n->line, "matching against non-constant expression");
+ break;
case Oucon: inferucon(st, n, &n->expr.isconst); break;
case Ovar:
s = getdcl(curstab(), args[0]);
@@ -1001,7 +1017,7 @@
n->expr.did = s->decl.did;
break;
default:
- die("Bad pattern to match against");
+ fatal(n->line, "invalid pattern");
break;
}
}