shithub: mc

ref: 894bed9a153ab56763627975f9e109281bec297b
dir: /libstd/bigint.myr/

View raw version
use "alloc.use"
use "cmp.use"
use "die.use"
use "extremum.use"
use "fmt.use"
use "hasprefix.use"
use "chartype.use"
use "option.use"
use "slcp.use"
use "sldup.use"
use "slfill.use"
use "slpush.use"
use "types.use"
use "utf.use"

pkg std =
	/*
	 an arbitrary precision integer, stored as a sign plus a
	 magnitude made up of 32 bit digits.
	 */
	type bigint = struct
		dig	: uint32[:] 	/* little endian, no leading zeros. */
		sign	: int		/* -1 for -ve, 0 for zero, 1 for +ve. */
	;;

	/*
	 administrivia: allocation, copying and text conversion.
	 mkbigint, bigdup and a successful bigparse hand back a
	 heap allocated bigint; bigfree releases one. bigfmt writes
	 into the buffer 'b' and returns the number of bytes it
	 produced. bigparse returns `None when it meets a character
	 that is not a digit in the base it is parsing.
	 */
	const mkbigint	: (v : int32 -> bigint#)
	const bigfree	: (a : bigint# -> void)
	const bigdup	: (a : bigint# -> bigint#)
	const bigassign	: (d : bigint#, s : bigint# -> bigint#)
	const bigparse	: (s : byte[:] -> option(bigint#))
	const bigfmt	: (b : byte[:], a : bigint# -> size)

	/* some useful predicates */
	const bigiszero	: (a : bigint# -> bool)
	const bigcmp	: (a : bigint#, b : bigint# -> order)

	/*
	 bigint*bigint -> bigint ops.
	 bigadd, bigsub, bigmul, bigdiv, bigmod, bigshl and bigshr are
	 written as compound assignments on 'a' (a += b, a -= b, ...).
	 bigdivmod is the exception: it returns a newly allocated
	 (quotient, remainder) pair.
	 */
	const bigadd	: (a : bigint#, b : bigint# -> bigint#)
	const bigsub	: (a : bigint#, b : bigint# -> bigint#)
	const bigmul	: (a : bigint#, b : bigint# -> bigint#)
	const bigdiv	: (a : bigint#, b : bigint# -> bigint#)
	const bigmod	: (a : bigint#, b : bigint# -> bigint#)
	const bigdivmod	: (a : bigint#, b : bigint# -> (bigint#, bigint#))
	const bigshl	: (a : bigint#, b : bigint# -> bigint#)
	const bigshr	: (a : bigint#, b : bigint# -> bigint#)

	/*
	 bigint*int -> bigint ops.
	 NOTE(review): of these, only bigshli and bigshri have a
	 definition in the part of the file reviewed here; confirm
	 that the remaining four are defined elsewhere.
	 */
	const bigaddi	: (a : bigint#, b : int64 -> bigint#)
	const bigsubi	: (a : bigint#, b : int64 -> bigint#)
	const bigmuli	: (a : bigint#, b : int64 -> bigint#)
	const bigdivi	: (a : bigint#, b : int64 -> bigint#)
	const bigshli	: (a : bigint#, b : uint64 -> bigint#)
	const bigshri	: (a : bigint#, b : uint64 -> bigint#)
;;

/* the radix of a single digit: 2^32, since digits are stored as uint32s. */
const Base = 0x100000000ul

/*
 allocates a new bigint holding the value 'v'. the magnitude of
 'v' goes into a single digit, and trim() then normalizes the
 result, so a zero 'v' ends up with no digits and a zero sign.
 */
const mkbigint = {v
	var big

	/* zalloc hands back zeroed memory, so the sign starts out as 0. */
	big = zalloc()
	big.dig = slalloc(1)
	if v > 0
		big.sign = 1
	elif v < 0
		big.sign = -1
		v = -v
	;;
	big.dig[0] = (v castto(uint32))
	-> trim(big)
}

/*
 releases a bigint: first the digit slice it owns, then the
 struct itself. 'a' must not be used afterwards.
 */
const bigfree = {a
	slfree(a.dig)
	free(a)
}

/*
 returns a newly allocated copy of 'a'. the copy starts out as a
 zeroed bigint and is then filled in through bigassign().
 */
const bigdup = {a
	var copy

	copy = zalloc()
	-> bigassign(copy, a)
}

/*
 d = s. replaces the digits and sign of 'd' with a copy of those
 of 's', releasing the digits 'd' held before. returns 'd'.
 */
const bigassign = {d, s
	var dig

	/*
	 duplicate before freeing: when 'd' and 's' are the same
	 bigint, freeing first would make sldup read freed memory.
	 */
	dig = sldup(s.dig)
	slfree(d.dig)
	d.dig = dig
	d.sign = s.sign
	-> d
}


/*
 formats 'val' as decimal text into 'buf', and returns the number
 of bytes written. zero is written as "0", and negative values get
 a leading '-'. 'val' itself is not modified: the digits are peeled
 off of a private copy.
 */
const bigfmt = {buf, val
	const digitchars = ['0','1','2','3','4','5','6','7','8','9','a','b','c','d','e','f']
	var v
	var n, i, start
	var tmp
	var rem
	var ten

	n = 0
	if val.sign == 0
		n += encode(buf, '0')
	elif val.sign == -1
		n += encode(buf, '-')
	;;
	/*
	 the digits start after whatever was written above. only
	 that part may be reversed below, otherwise the '-' would
	 end up at the tail of the number.
	 */
	start = n

	val = bigdup(val)
	/* one divisor for the whole loop, released once we are done. */
	ten = mkbigint(10)
	/* generate the digits in reverse order */
	while !bigiszero(val)
		(v, rem) = bigdivmod(val, ten)
		if rem.dig.len > 0
			n += bfmt(buf[n:], "%c", digitchars[rem.dig[0]])
		else
			/* a zero remainder has no digits at all. */
			n += bfmt(buf[n:], "0")
		;;
		bigfree(val)
		bigfree(rem)
		val = v
	;;
	bigfree(val)
	bigfree(ten)
	/* we only generated ascii digits, so a bytewise reversal works. ugh. */
	for i = 0; i < (n - start)/2; i++
		tmp = buf[start + i]
		buf[start + i] = buf[n - i - 1]
		buf[n - i - 1] = tmp
	;;
	-> n
}

/*
 parses 'str' into a newly allocated bigint. a leading "0x", "0o"
 or "0b" (in either case) selects base 16, 8 or 2; anything else
 is read as base 10. '_' characters are skipped. returns `None if
 a character is not a digit in the selected base, and `Some of the
 value otherwise. no sign is parsed, and a string without any
 digits yields zero.
 */
const bigparse = {str
	var c, val, base
	var v, b
	var a

	/* the prefix only picks the base; it is not part of the digits. */
	if hasprefix(str, "0x") || hasprefix(str, "0X")
		base = 16
		str = str[2:]
	elif hasprefix(str, "0o") || hasprefix(str, "0O")
		base = 8
		str = str[2:]
	elif hasprefix(str, "0b") || hasprefix(str, "0B")
		base = 2
		str = str[2:]
	else
		base = 10
	;;

	a = mkbigint(0)
	b = mkbigint(base castto(int32))
	/*
	 efficiency hack: to save allocations,
	 just mutate v[0]. The value will always
	 fit in one digit.
	 */
	v = mkbigint(1)
	while str.len != 0
		(c, str) = striter(str)
		if c == '_'
			continue
		;;
		val = charval(c, base)
		if val < 0
			/* the temporaries are ours too, not just the result. */
			bigfree(a)
			bigfree(b)
			bigfree(v)
			-> `None
		;;
		/*
		 cast explicitly: storing val directly would tie its
		 type to the unsigned digit type, and the 'val < 0'
		 check above could then never fire.
		 */
		v.dig[0] = val castto(uint32)
		if val == 0
			v.sign = 0
		else
			v.sign = 1
		;;
		/* a = a*base + digit */
		bigmul(a, b)
		bigadd(a, v)

	;;
	bigfree(b)
	bigfree(v)
	-> `Some a
}

/* true if 'v' is zero, which is represented by having no digits at all. */
const bigiszero = {v
	if v.dig.len != 0
		-> false
	;;
	-> true
}

/*
 compares two bigints by value. returns `Before if a < b, `After
 if a > b, and `Equal if they hold the same value.
 */
const bigcmp = {a, b
	var i

	/* differing signs settle it immediately. */
	if a.sign < b.sign
		-> `Before
	elif a.sign > b.sign
		-> `After
	elif a.sign == 0
		/* both are zero. */
		-> `Equal
	;;

	/*
	 same sign from here on, so compare magnitudes. a larger
	 magnitude means a larger value for positive numbers, and
	 a smaller one for negative numbers; signedorder() handles
	 that flip, and negating the sign gives the opposite answer.
	 */
	if a.dig.len > b.dig.len
		-> signedorder(a.sign)
	elif a.dig.len < b.dig.len
		-> signedorder(-a.sign)
	;;
	/* equal lengths: the most significant differing digit decides. */
	for i = a.dig.len; i > 0; i--
		if a.dig[i - 1] > b.dig[i - 1]
			-> signedorder(a.sign)
		elif a.dig[i - 1] < b.dig[i - 1]
			-> signedorder(-a.sign)
		;;
	;;
	-> `Equal
}

/* maps a sign onto an ordering: `Before for a negative sign, `After otherwise. */
const signedorder = {sign
	if sign >= 0
		-> `After
	;;
	-> `Before
}

/*
 a += b. updates 'a' in place and returns it. 'b' is only read.
 */
const bigadd = {a, b
	if a.sign == b.sign || a.sign == 0
		/* same sign, or a is zero: the magnitudes add up, and b supplies the sign. */
		a.sign = b.sign
		-> uadd(a, b)
	elif b.sign == 0
		-> a
	else
		/*
		 opposite signs: the result is the difference of the
		 magnitudes, and takes the sign of the larger one.
		 */
		match ucmpmag(a, b)
		| `Before:
			a.sign = b.sign
			-> ursub(a, b)
		| `After:
			-> usub(a, b)
		| `Equal:
			/* x + -x: usub leaves only zero digits, and trim makes that 0. */
			-> usub(a, b)
		;;
	;;
}

/*
 adds the magnitude of 'b' to the magnitude of 'a', in place.
 signs are the caller's business. returns 'a'.
 */
const uadd = {a, b
	var v, i
	var carry
	var n

	carry = 0
	n = max(a.dig.len, b.dig.len)
	/* guaranteed to carry no more than one value */
	a.dig = slzgrow(a.dig, n + 1)
	for i = 0; i < n; i++
		v = (a.dig[i] castto(uint64)) + carry
		/* 'b' may be the shorter operand; missing digits count as zero. */
		if i < b.dig.len
			v += (b.dig[i] castto(uint64))
		;;
		if v >= Base
			carry = 1
		else
			carry = 0
		;;
		a.dig[i] = v castto(uint32)
	;;
	/* i == n here: the spare digit allocated above takes the final carry. */
	a.dig[i] += carry castto(uint32)
	-> trim(a)
}

/*
 a -= b. updates 'a' in place and returns it. 'b' is only read.
 */
const bigsub = {a, b
	/* 0 - x = -x */
	if a.sign == 0
		bigassign(a, b)
		a.sign = -b.sign
		-> a
	/* x - 0 = x */
	elif b.sign == 0
		-> a
	elif a.sign != b.sign
		/* opposite signs: the magnitudes add up, and a keeps its sign. */
		-> uadd(a, b)
	else
		/*
		 same sign: the result is the difference of the
		 magnitudes. it keeps the sign of a if a has the
		 larger magnitude, and flips it otherwise.
		 */
		match ucmpmag(a, b)
		| `Before:
			a.sign = -a.sign
			-> ursub(a, b)
		| `After:
			-> usub(a, b)
		| `Equal:
			/* x - x: usub leaves only zero digits, and trim makes that 0. */
			-> usub(a, b)
		;;
	;;
	-> a
}

/*
 a = |a| - |b|, in place, where the magnitude of 'a' is at least
 that of 'b'. signs are the caller's business. returns 'a'.
 */
const usub = {a, b
	var carry
	var v, i

	carry = 0
	for i = 0; i < a.dig.len; i++
		v = (a.dig[i] castto(int64)) - carry
		/* 'b' may be the shorter operand; missing digits count as zero. */
		if i < b.dig.len
			v -= (b.dig[i] castto(int64))
		;;
		if v < 0
			carry = 1
		else
			carry = 0
		;;
		/* truncating keeps v mod 2^32, which is the borrowed digit. */
		a.dig[i] = v castto(uint32)
	;;
	-> trim(a)
}

/*
 a = |b| - |a|, in place, where the magnitude of 'b' is strictly
 greater than that of 'a'. only 'a' is written; this is what lets
 bigadd and bigsub handle the "b is bigger" case without touching
 'b'. signs are the caller's business. returns 'a'.
 */
const ursub = {a, b
	var carry
	var v, i

	/* make room for every digit of b; the new digits read as zero. */
	a.dig = slzgrow(a.dig, b.dig.len)
	carry = 0
	for i = 0; i < b.dig.len; i++
		v = (b.dig[i] castto(int64)) - (a.dig[i] castto(int64)) - carry
		if v < 0
			carry = 1
		else
			carry = 0
		;;
		a.dig[i] = v castto(uint32)
	;;
	-> trim(a)
}

/*
 compares the magnitudes of 'a' and 'b', ignoring their signs.
 relies on neither having leading zero digits, so that more digits
 means a larger magnitude.
 */
const ucmpmag = {a, b
	var i

	if a.dig.len > b.dig.len
		-> `After
	elif a.dig.len < b.dig.len
		-> `Before
	;;
	/* equal lengths: the most significant differing digit decides. */
	for i = a.dig.len; i > 0; i--
		if a.dig[i - 1] > b.dig[i - 1]
			-> `After
		elif a.dig[i - 1] < b.dig[i - 1]
			-> `Before
		;;
	;;
	-> `Equal
}

/*
 a *= b. updates 'a' in place and returns it. 'b' is only read.
 this is schoolbook multiplication: every digit of b is multiplied
 with every digit of a, accumulating into a zeroed scratch slice
 that then replaces the digits of a.
 */
const bigmul = {a, b
	var i, j
	var ai, bj, wij
	var carry, t
	var w

	if a.sign == 0 || b.sign == 0
		/* anything times zero is zero, which has no digits. */
		a.sign = 0
		slfree(a.dig)
		a.dig = [][:]
		-> a
	elif a.sign != b.sign
		a.sign = -1
	else
		a.sign = 1
	;;
	/* the product never needs more digits than both operands together. */
	w  = slzalloc(a.dig.len + b.dig.len)
	for j = 0; j < b.dig.len; j++
		/* a zero digit of b adds nothing, and w already starts out zeroed. */
		if b.dig[j] == 0
			continue
		;;
		bj = b.dig[j] castto(uint64)
		carry = 0
		for i = 0; i < a.dig.len; i++
			ai = a.dig[i] castto(uint64)
			wij = w[i+j] castto(uint64)
			/* at most (2^32 - 1)^2 + 2*(2^32 - 1) = 2^64 - 1, so this cannot overflow. */
			t = ai * bj + wij + carry
			w[i + j] = (t castto(uint32))
			carry = t >> 32
		;;
		/* i == a.dig.len here: the carry lands one digit above this row. */
		w[i+j] = carry castto(uint32)
	;;
	slfree(a.dig)
	a.dig = w
	-> trim(a)
}

/*
 a /= b. replaces the value of 'a' with the quotient computed by
 bigdivmod(), and returns 'a'. 'b' is only read.
 */
const bigdiv = {a : bigint#, b : bigint# -> bigint#
	var q, r

	(q, r) = bigdivmod(a, b)
	bigfree(r)
	/* take over both the digits and the sign of the quotient... */
	slfree(a.dig)
	a.dig = q.dig
	a.sign = q.sign
	/* ...and drop only its struct, since a now owns the digits. */
	free(q)
	-> a
}
/*
 a %= b. replaces the value of 'a' with the remainder computed by
 bigdivmod(), and returns 'a'. 'b' is only read.
 */
const bigmod = {a : bigint#, b : bigint# -> bigint#
	var q, r

	(q, r) = bigdivmod(a, b)
	bigfree(q)
	/* take over both the digits and the sign of the remainder... */
	slfree(a.dig)
	a.dig = r.dig
	a.sign = r.sign
	/* ...and drop only its struct, since a now owns the digits. */
	free(r)
	-> a
}

/*
 divides 'a' by 'b', returning a newly allocated (quotient,
 remainder) pair. neither 'a' nor 'b' is modified. the quotient is
 negative when the signs of the operands differ, and the remainder
 takes the sign of 'a'. dies if 'b' is zero.
 */
const bigdivmod = {a : bigint#, b : bigint# -> (bigint#, bigint#)
	/*
	 Implements bigint division using Algorithm D from
	 Knuth: Seminumerical algorithms, Section 4.3.1.
	 */
	var qhat, rhat, carry, shift
	var x, y, z, w, p, t /* temporaries */
	var b0, aj
	var u, v
	var m : int64, n : int64
	var i, j : int64
	var q, r

	if bigiszero(b)
		die("divide by zero\n")
	;;
	/* if b has more digits than a, we truncate to 0, with remainder 'a' */
	if a.dig.len < b.dig.len
		-> (mkbigint(0), bigdup(a))
	;;

	q = zalloc()
	q.dig = slzalloc(max(a.dig.len, b.dig.len) + 1)
	if a.sign != b.sign
		q.sign = -1
	else
		q.sign = 1
	;;

	/* handle single digit divisor separately: the knuth algorithm needs at least 2 digits. */
	if b.dig.len == 1
		carry = 0
		b0 = (b.dig[0] castto(uint64))
		/* plain long division, from the most significant digit down. */
		for j = a.dig.len; j > 0; j--
			aj = (a.dig[j - 1] castto(uint64))
			q.dig[j - 1] = (((carry << 32) + aj)/b0) castto(uint32)
			carry = (carry << 32) + aj - (q.dig[j-1] castto(uint64))*b0
		;;
		/*
		 build the remainder by hand: it can be as large as
		 2^32 - 2, which does not survive a trip through the
		 int32 argument of mkbigint.
		 */
		r = zalloc()
		r.dig = slalloc(1)
		r.dig[0] = carry castto(uint32)
		r.sign = a.sign
		-> (trim(q), trim(r))
	;;

	u = bigdup(a)
	v = bigdup(b)
	m = u.dig.len
	n = v.dig.len

	/* normalize, so that the top digit of the divisor has its high bit set. */
	shift = nlz(v.dig[n - 1])
	bigshli(u, shift)
	bigshli(v, shift)
	/*
	 bigshli trims its result, but the loop below reads the digit
	 above the original top digit of u on its first pass. make
	 sure it exists; if it had to be added, it reads as zero.
	 */
	u.dig = slzgrow(u.dig, a.dig.len + 1)
	for j = m - n; j >= 0; j--
		/* load a few temps */
		x = u.dig[j + n] castto(uint64)
		y = u.dig[j + n - 1] castto(uint64)
		z = v.dig[n - 1] castto(uint64)
		w = v.dig[n - 2] castto(uint64)
		t = u.dig[j + n - 2] castto(uint64)

		/* estimate qhat */
		qhat = (x*Base + y)/z
		rhat = (x*Base + y) - (qhat * z)
:divagain
		if qhat >= Base || (qhat * w) > (rhat*Base + t)
			qhat--
			rhat += z
			if rhat < Base
				goto divagain
			;;
		;;

		/* multiply and subtract */
		carry = 0
		for i = 0; i < n; i++
			p = qhat * (v.dig[i] castto(uint64))
			t = (u.dig[i+j] castto(uint64)) - carry - (p % Base)
			u.dig[i+j] = t castto(uint32)
			/*
			 the high half of the product is taken with an
			 unsigned shift, since p may use all 64 bits.
			 t is shifted as a signed value on purpose: a
			 wrapped (negative) t must come out as a borrow.
			 */
			carry = (p >> 32) - (((t castto(int64)) >> 32) castto(uint64))
		;;
		x = u.dig[j + n] castto(uint64)
		t = x - carry
		u.dig[j + n] = t castto(uint32)
		q.dig[j] = qhat castto(uint32)
		/* adjust: qhat was one too large, so add the divisor back. */
		if x < carry
			q.dig[j]--
			carry = 0
			for i = 0; i < n; i++
				t = (u.dig[i+j] castto(uint64)) + (v.dig[i] castto(uint64)) + carry
				u.dig[i+j] = t castto(uint32)
				carry = t >> 32
			;;
			u.dig[j+n] = u.dig[j+n] + (carry castto(uint32));
		;;

	;;
	/* the shifted copy of the divisor was only scratch space. */
	bigfree(v)
	/* undo the biasing for remainder */
	bigshri(u, shift)
	-> (trim(q), trim(u))
}

/*
 returns the number of leading zero bits in 'a', which is 32 when
 'a' is zero. works by binary search: each step checks whether the
 top 16, 8, 4, 2 and finally 1 bits are all clear, and if so counts
 them and shifts them out of the way.
 */
const nlz = {a : uint32
	var n

	if a == 0
		-> 32
	;;
	n = 0
	if (a >> 16) == 0
		n += 16
		a <<= 16
	;;
	if (a >> 24) == 0
		n += 8
		a <<= 8
	;;
	if (a >> 28) == 0
		n += 4
		a <<= 4
	;;
	if (a >> 30) == 0
		n += 2
		a <<= 2
	;;
	if (a >> 31) == 0
		n += 1
	;;
	-> n
}


/*
 a <<= b, with the shift count given as a bigint. a count with no
 digits leaves 'a' alone, a single digit count is handed to
 bigshli(), and anything bigger is refused.
 */
const bigshl = {a, b
	if b.dig.len == 0
		-> a
	elif b.dig.len == 1
		-> bigshli(a, b.dig[0] castto(uint64))
	else
		die("shift by way too much\n")
	;;
}

/*
 a >>= b, unsigned, with the shift count given as a bigint. a count
 with no digits leaves 'a' alone, a single digit count is handed to
 bigshri(), and anything bigger is refused.
 */
const bigshr = {a, b
	if b.dig.len == 0
		-> a
	elif b.dig.len == 1
		-> bigshri(a, b.dig[0] castto(uint64))
	else
		die("shift by way too much\n")
	;;
}

/*
  a <<= s, with an integer shift count. updates 'a' in place and
  returns it. the shift is done in two steps: whole digits are
  moved first, and the remaining 0..31 bits are shifted afterwards.
 */
const bigshli = {a, s
	var words, bits
	var wide, spill
	var k

	words = s/32
	bits = s % 32

	/* room for the moved digits, plus one for the bits spilling out of the top. */
	a.dig = slzgrow(a.dig, 1 + a.dig.len + words castto(size))
	/* move every digit up by 'words' places, starting from the top... */
	k = a.dig.len
	while k > words
		k--
		a.dig[k] = a.dig[k - words]
	;;
	/* ...and clear the places that were vacated at the bottom. */
	while k > 0
		k--
		a.dig[k] = 0
	;;
	/* shift by the leftover bits, handing the overflow of each digit to the next. */
	spill = 0
	for k = 0; k < a.dig.len; k++
		wide = (a.dig[k] castto(uint64)) << bits
		a.dig[k] = (wide | spill) castto(uint32) 
		spill = wide >> 32
	;;
	-> trim(a)
}

/*
 a >>= s, with an integer shift count. logical shift right, zero
 fills. updates 'a' in place and returns it. the sign is left alone
 unless every bit gets shifted out, in which case trim() turns the
 result into zero.
 */
const bigshri = {a, s
	var off, shift
	var t, carry
	var i

	off = s/32
	shift = s % 32

	/* shifting by at least as many digits as there are leaves nothing. */
	if off >= a.dig.len
		for i = 0; i < a.dig.len; i++
			a.dig[i] = 0
		;;
		-> trim(a)
	;;

	/* blit over the base values */
	for i = 0; i < a.dig.len - off; i++
		a.dig[i] = a.dig[i + off]
	;;
	/* clear the digits vacated at the top; these all lie inside the slice. */
	for i = a.dig.len - off; i < a.dig.len; i++
		a.dig[i] = 0
	;;
	/* and shift over by the remainder */
	carry = 0
	for i = a.dig.len; i > 0; i--
		t = (a.dig[i - 1] castto(uint64))
		a.dig[i - 1] = (carry | (t >> shift)) castto(uint32) 
		/* the bits shifted out of this digit become the top bits of the one below. */
		carry = t << (32 - shift)
	;;
	-> trim(a)
}

/*
 normalizes 'a' in place and returns it: leading (most significant)
 zero digits are cut off, and the sign is made to agree with the
 digits. a value without digits gets sign 0, and a nonzero value
 that still has sign 0 is taken to be positive.
 */
const trim = {a
	var n

	/* find the number of digits left once the zeros at the top are gone. */
	n = a.dig.len
	while n > 0 && a.dig[n - 1] == 0
		n--
	;;
	a.dig = slgrow(a.dig, n)
	if n == 0
		a.sign = 0
	elif a.sign == 0
		a.sign = 1
	;;
	-> a
}