/*
 * Source: shithub repository "mc"
 * ref: 964632bdffbee7760477e2faa3c862a18528f518
 * path: /libstd/bigint.myr
 */
use "alloc.use"
use "cmp.use"
use "die.use"
use "extremum.use"
use "fmt.use"
use "option.use"
use "slcp.use"
use "sldup.use"
use "slpush.use"
use "types.use"
use "utf.use"

pkg std =
	/*
	 * Integer held as a sign plus a slice of 32-bit magnitude digits.
	 * The functions in this file keep the digit slice free of leading
	 * zeros (see trim), so zero is the empty slice with sign 0.
	 */
	type bigint = struct
		dig	: uint32[:] 	/* little endian, no leading zeros. */
		sign	: int		/* -1 for -ve, 0 for zero, 1 for +ve. */
	;;

	/*
	 * administrivia: allocation, release, copying and text conversion.
	 * NOTE(review): the only bigparse body in this file is commented out,
	 * and it takes a string and yields an optional bigint -- the signature
	 * below looks inverted relative to that; confirm before relying on it.
	 */
	const mkbigint	: (v : int32 -> bigint#)
	const bigfree	: (a : bigint# -> void)
	const bigdup	: (a : bigint# -> bigint#)
	const bigparse	: (a : bigint# -> option(byte[:]))
	const bigfmt	: (b : byte[:], a : bigint# -> size)

	/* some useful predicates */
	const bigiszero	: (a : bigint# -> bool)
	const bigcmp	: (a : bigint#, b : bigint# -> order)

	/*
	 * bigint*bigint -> bigint ops.
	 * The definitions in this file update 'a' in place and return it;
	 * bigdivmod is the exception and returns a freshly allocated
	 * (quotient, remainder) pair.
	 */
	const bigadd	: (a : bigint#, b : bigint# -> bigint#)
	const bigsub	: (a : bigint#, b : bigint# -> bigint#)
	const bigmul	: (a : bigint#, b : bigint# -> bigint#)
	const bigdiv	: (a : bigint#, b : bigint# -> bigint#)
	const bigmod	: (a : bigint#, b : bigint# -> bigint#)
	const bigdivmod	: (a : bigint#, b : bigint# -> [bigint#, bigint#])
	const bigshl	: (a : bigint#, b : bigint# -> bigint#)
	const bigshr	: (a : bigint#, b : bigint# -> bigint#)
	const bigshra	: (a : bigint#, b : bigint# -> bigint#)

	/*
	 * bigint*int -> bigint ops.
	 * NOTE(review): no definition of bigaddi, bigsubi, bigmuli or bigdivi
	 * is visible in this file; only the shift variants are defined here.
	 */
	const bigaddi	: (a : bigint#, b : int64 -> bigint#)
	const bigsubi	: (a : bigint#, b : int64 -> bigint#)
	const bigmuli	: (a : bigint#, b : int64 -> bigint#)
	const bigdivi	: (a : bigint#, b : int64 -> bigint#)
	const bigshli	: (a : bigint#, b : uint64 -> bigint#)
	const bigshri	: (a : bigint#, b : uint64 -> bigint#)
	const bigshrai	: (a : bigint#, b : uint64 -> bigint#)
;;

/*
 * Allocates a bigint holding the 32-bit value 'v'. The result must be
 * released with bigfree.
 */
const mkbigint = {v
	var big

	big = zalloc()
	big.dig = slalloc(1)
	/* v == 0 needs no branch: trim() forces the sign to 0 for an all-zero value */
	if v > 0
		big.sign = 1
	elif v < 0
		big.sign = -1
		v = -v
	;;
	/* a single digit carries the magnitude */
	big.dig[0] = (v castto(uint32))
	-> trim(big)
}

/* Releases the digit slice of 'a' and then the bigint itself. */
const bigfree = {a
	var digits

	digits = a.dig
	slfree(digits)
	free(a)
}

/* Returns a newly allocated bigint with a copy of a's digits and a's sign. */
const bigdup = {a
	var copy

	copy = zalloc()
	copy.sign = a.sign
	copy.dig = sldup(a.dig)
	-> copy
}

/*
 * Writes 'val' in decimal into 'buf' and returns the number of bytes
 * produced. 'val' itself is left untouched; the digits are peeled off a
 * private copy. For now this is mostly a debugging aid.
 */
const bigfmt = {buf, val
	const digitchars = ['0','1','2','3','4','5','6','7','8','9','a','b','c','d','e','f']
	var v, rem, ten
	var n, i, start
	var tmp

	n = 0
	if val.sign == 0
		n += encode(buf, '0')
		-> n
	elif val.sign == -1
		n += encode(buf, '-')
	;;
	/* the digits start after the optional sign; only they are reversed below */
	start = n

	val = bigdup(val)
	/* a single divisor for the whole loop, freed once at the end */
	ten = mkbigint(10)
	/* generate the digits in reverse order */
	while !bigiszero(val)
		(v, rem) = bigdivmod(val, ten)
		if rem.dig.len > 0
			n += bfmt(buf[n:], "%c", digitchars[rem.dig[0]])
		else
			/* a zero remainder has no digits at all */
			n += bfmt(buf[n:], "0")
		;;
		bigfree(val)
		bigfree(rem)
		val = v
	;;
	bigfree(ten)
	bigfree(val)
	/* we only generated ascii digits, so a bytewise reversal works. ugh. */
	for i = 0; i < (n - start)/2; i++
		tmp = buf[start + i]
		buf[start + i] = buf[n - i - 1]
		buf[n - i - 1] = tmp
	;;
	-> n
}

/*
const bigparse = {str
	var c, i
	var val, base
	var a

	if hasprefix(str, "0x") || hasprefix(str, "0X")
		base = 16
	if hasprefix(str, "0o") || hasprefix(str, "0O")
		base = 8
	if hasprefix(str, "0b") || hasprefix(str, "0B")
		base = 2
	else
		base = 10
	;;

	a = mkbigint()
	while str.len != 0
		(c, str) = striter(str)
		val = charval(c)
		if val < 0
			bigfree(a)
			-> `None
		;;
		bigmuli(a, base)
		bigaddi(a, val)
	;;
	-> `Some a
}
*/

/* True exactly when the sign field of 'v' says zero. */
const bigiszero = {v
	if v.sign == 0
		-> true
	;;
	-> false
}

/*
 * Three-way signed comparison: `Before when a < b, `After when a > b,
 * `Equal otherwise. Relies on neither operand having leading zero digits.
 */
const bigcmp = {a, b
	var i

	if a.sign < b.sign
		-> `Before
	elif a.sign > b.sign
		-> `After
	;;

	/*
	 * Same sign from here on. Compare magnitudes; a larger magnitude
	 * means "greater" for positives and "smaller" for negatives, hence
	 * signedmagorder(a.sign) when |a| wins and the negated sign when
	 * |b| wins.
	 */
	if a.dig.len > b.dig.len
		-> signedmagorder(a.sign)
	elif a.dig.len < b.dig.len
		-> signedmagorder(-a.sign)
	;;
	/* equal lengths: the first differing digit from the top decides */
	for i = a.dig.len; i > 0; i--
		if a.dig[i - 1] > b.dig[i - 1]
			-> signedmagorder(a.sign)
		elif a.dig[i - 1] < b.dig[i - 1]
			-> signedmagorder(-a.sign)
		;;
	;;
	-> `Equal
}

/*
 * Maps a sign onto an order: negative gives `Before, anything else
 * gives `After.
 */
const signedmagorder = {sign
	if sign >= 0
		-> `After
	;;
	-> `Before
}

/* a += b; 'a' is updated in place and returned, 'b' is not modified. */
const bigadd = {a, b
	-> bigaddsigned(a, b, b.sign)
}

/*
 * Adds two magnitudes: |a| += |b|. The sign of 'a' is left alone (apart
 * from trim() clearing it when the result is zero). 'b' may be longer or
 * shorter than 'a'.
 */
const uadd = {a, b
	var v, i
	var carry
	var n

	carry = 0
	n = max(a.dig.len, b.dig.len)
	/* room for every digit of the longer operand, plus one digit of carry */
	a.dig = slzgrow(a.dig, n + 1)
	for i = 0; i < n; i++
		v = (a.dig[i] castto(uint64)) + carry
		if i < b.dig.len
			v += (b.dig[i] castto(uint64))
		;;
		a.dig[i] = v castto(uint32)
		carry = v >> 32
	;;
	a.dig[n] = carry castto(uint32)
	-> trim(a)
}

/* a -= b; 'a' is updated in place and returned, 'b' is not modified. */
const bigsub = {a, b
	/* subtracting b is adding a value with b's magnitude and the opposite sign */
	-> bigaddsigned(a, b, -b.sign)
}

/*
 * Subtracts magnitudes: |a| -= |b|, where |a| must be at least |b|.
 * 'b' may have fewer digits than 'a'; the missing ones count as zero.
 */
const usub = {a, b
	var borrow
	var v, i

	borrow = 0
	for i = 0; i < a.dig.len; i++
		v = (a.dig[i] castto(int64)) - borrow
		if i < b.dig.len
			v -= (b.dig[i] castto(int64))
		;;
		if v < 0
			borrow = 1
		else
			borrow = 0
		;;
		/* truncation wraps a negative v back into digit range */
		a.dig[i] = v castto(uint32)
	;;
	-> trim(a)
}

/*
 * Shared core of bigadd and bigsub: a += (bsign * |b|), where 'bsign' is
 * the sign that b's magnitude should be treated as having.
 */
const bigaddsigned = {a, b, bsign
	var t

	if bsign == 0
		/* nothing to add */
		-> a
	elif a.sign == 0
		/* 0 + x: take a copy of b's digits with the effective sign */
		slfree(a.dig)
		a.dig = sldup(b.dig)
		a.sign = bsign
		-> a
	elif a.sign == bsign
		/* same sign: magnitudes add, sign is unchanged */
		-> uadd(a, b)
	;;

	/* opposite signs: the larger magnitude wins and keeps its sign */
	match umagcmp(a, b)
	| `After:
		-> usub(a, b)
	| `Equal:
		/* the difference is zero; trim() inside usub clears the sign */
		-> usub(a, b)
	| `Before:
		/* |b| > |a|: compute |b| - |a| in a scratch copy so b stays intact */
		t = bigdup(b)
		usub(t, a)
		slfree(a.dig)
		a.dig = t.dig
		a.sign = bsign
		free(t)
		-> a
	;;
}

/* Compares magnitudes only, ignoring signs. Assumes no leading zero digits. */
const umagcmp = {a, b
	var i

	if a.dig.len > b.dig.len
		-> `After
	elif a.dig.len < b.dig.len
		-> `Before
	;;
	for i = a.dig.len; i > 0; i--
		if a.dig[i - 1] > b.dig[i - 1]
			-> `After
		elif a.dig[i - 1] < b.dig[i - 1]
			-> `Before
		;;
	;;
	-> `Equal
}

/*
 * a *= b, schoolbook multiplication. 'a' is updated in place and
 * returned; 'b' is only read.
 */
const bigmul = {a, b
	var i, j
	var ai, bj, wij
	var carry, t
	var w

	/* provisional sign; trim() resets it to 0 if the product is zero */
	if a.sign != b.sign
		a.sign = -1
	else
		a.sign = 1
	;;
	/* zero-filled accumulator wide enough for the full product */
	w  = slzalloc(a.dig.len + b.dig.len)
	for j = 0; j < b.dig.len; j++
		bj = b.dig[j] castto(uint64)
		/* a zero digit of the multiplier contributes nothing to w */
		if bj == 0
			continue
		;;
		carry = 0
		for i = 0; i < a.dig.len; i++
			ai = a.dig[i] castto(uint64)
			wij = w[i+j] castto(uint64)
			/* cannot overflow: (2^32-1)^2 + 2*(2^32-1) == 2^64 - 1 */
			t = ai * bj + wij + carry
			w[i + j] = (t castto(uint32))
			carry = t >> 32
		;;
		/* i == a.dig.len here: store the carry out of this row */
		w[i+j] = carry castto(uint32)
	;;
	slfree(a.dig)
	a.dig = w
	-> trim(a)
}

/*
 * a /= b. 'a' takes over the quotient computed by bigdivmod (digits and
 * sign) and is returned; the remainder is discarded.
 */
const bigdiv = {a : bigint#, b : bigint# -> bigint#
	var q, r

	(q, r) = bigdivmod(a, b)
	bigfree(r)
	slfree(a.dig)
	/* steal the quotient's digits, then free only its header */
	a.dig = q.dig
	/* the quotient's sign can differ from a's (mixed signs, or a zero result) */
	a.sign = q.sign
	free(q)
	-> a
}
/*
 * a %= b. 'a' takes over the remainder digits computed by bigdivmod and
 * is returned; the quotient is discarded. 'a' keeps its own sign (the
 * remainder follows the dividend), except that a zero remainder gets
 * sign 0.
 */
const bigmod = {a : bigint#, b : bigint# -> bigint#
	var q, r

	(q, r) = bigdivmod(a, b)
	bigfree(q)
	slfree(a.dig)
	/* steal the remainder's digits, then free only its header */
	a.dig = r.dig
	/* keep the "zero has sign 0" invariant when the division is exact */
	if r.sign == 0
		a.sign = 0
	;;
	free(r)
	-> a
}

/*
 * Divides a by b and returns a freshly allocated (quotient, remainder)
 * pair; neither argument is modified. The quotient's magnitude is
 * |a| / |b| rounded down, negative when the signs differ, and the
 * remainder takes the sign of the dividend, so q*b + r == a.
 *
 * Only single-digit divisors are supported; anything larger dies.
 * Division by zero dies as well.
 */
const bigdivmod = {a : bigint#, b : bigint# -> [bigint#, bigint#]
	var q, r
	var j
	var carry, b0, aj, cur

	if bigiszero(b)
		die("divide by zero\n")
	;;
	/* TODO: long division for multi-digit divisors */
	if b.dig.len != 1
		die("big bigint division not implemented\n")
	;;

	q = zalloc()
	q.dig = slzalloc(max(a.dig.len, b.dig.len))
	/* provisional sign; trim() resets it to 0 for a zero quotient */
	if a.sign != b.sign
		q.sign = -1
	else
		q.sign = 1
	;;

	/* short division, most significant digit first */
	carry = 0
	b0 = (b.dig[0] castto(uint64))
	for j = a.dig.len; j > 0; j--
		aj = (a.dig[j - 1] castto(uint64))
		/* carry < b0, so cur/b0 always fits in one digit */
		cur = (carry << 32) + aj
		q.dig[j - 1] = (cur/b0) castto(uint32)
		carry = cur - (q.dig[j - 1] castto(uint64))*b0
	;;

	/*
	 * The remainder can be any value below b0, i.e. up to 2^32 - 2, so
	 * build its digit directly rather than squeezing it through an int32.
	 */
	r = zalloc()
	r.dig = slalloc(1)
	r.dig[0] = carry castto(uint32)
	r.sign = a.sign
	-> (trim(q), trim(r))
}

/*
 * a <<= b. Only shift counts that fit in a single digit are handled;
 * larger counts die. The sign of 'b' is not consulted.
 */
const bigshl = {a, b
	if b.dig.len == 0
		/* shifting by zero */
		-> a
	elif b.dig.len == 1
		-> bigshli(a, b.dig[0] castto(uint64))
	else
		die("shift by way too much\n")
	;;
}

/*
 * a >>= b, unsigned. Only shift counts that fit in a single digit are
 * handled; larger counts die. The sign of 'b' is not consulted.
 */
const bigshr = {a, b
	if b.dig.len == 0
		/* shifting by zero */
		-> a
	elif b.dig.len == 1
		-> bigshri(a, b.dig[0] castto(uint64))
	else
		die("shift by way too much\n")
	;;
}

/*
 * a >>= b, sign extending. Only shift counts that fit in a single digit
 * are handled; larger counts die. The sign of 'b' is not consulted.
 */
const bigshra = {a, b
	if b.dig.len == 0
		/* shifting by zero */
		-> a
	elif b.dig.len == 1
		-> bigshrai(a, b.dig[0] castto(uint64))
	else
		die("shift by way too much\n")
	;;
}

/*
 * Restores the representation invariants of 'a': drops zero digits from
 * the most significant end and sets the sign to 0 when no digits remain.
 * Returns 'a'.
 */
const trim = {a
	var n

	/* n ends up as the number of digits worth keeping */
	n = a.dig.len
	while n > 0 && a.dig[n - 1] == 0
		n--
	;;
	a.dig = slgrow(a.dig, n)
	if n == 0
		a.sign = 0
	;;
	-> a
}

/*
 * a <<= s for an integer bit count. The shift is done in two passes:
 * whole 32-bit digits first, then the leftover bits. Returns 'a'.
 */
const bigshli = {a, s
	var words, bits
	var wide, spill
	var k

	words = s/32
	bits = s % 32

	/* one spare digit on top for bits pushed out of the current top digit */
	a.dig = slzgrow(a.dig, 1 + a.dig.len + words castto(size))
	/* move digits up by 'words', top down so nothing is read after being overwritten */
	for k = a.dig.len; k > words; k--
		a.dig[k - 1] = a.dig[k - 1 - words]
	;;
	/* the vacated low digits become zero */
	for k = 0; k < words; k++
		a.dig[k] = 0
	;;
	/* sub-digit shift, carrying the overflow of each digit into the next */
	spill = 0
	for k = 0; k < a.dig.len; k++
		wide = (a.dig[k] castto(uint64)) << bits
		a.dig[k] = (wide | spill) castto(uint32)
		spill = wide >> 32
	;;
	-> trim(a)
}

/*
 * a >>= s for an integer bit count, unsigned: delegates to bigshrfill
 * with a fill digit of 0 and returns its result.
 */
const bigshri = {a, s
	-> bigshrfill(a, s, 0)
}

/*
 * a >>= s for an integer bit count, arithmetic: the result is a / 2^s
 * rounded towards negative infinity, which is what a sign-extending
 * shift yields on a two's complement value.
 *
 * The digits hold a magnitude, not a two's complement pattern, so
 * shifting one bits in from the top would only corrupt the value.
 * Instead the magnitude is shifted with zero fill and, for a negative
 * 'a' that lost nonzero bits, bumped by one.
 */
const bigshrai = {a, s
	var off, shift
	var lost
	var i

	if a.sign != -1
		-> bigshrfill(a, s, 0)
	;;

	/* find out whether any nonzero bit is about to be shifted out */
	off = (s / 32) castto(size)
	shift = s % 32
	lost = false
	for i = 0; i < a.dig.len && i < off; i++
		if a.dig[i] != 0
			lost = true
		;;
	;;
	if off < a.dig.len && shift != 0
		if ((a.dig[off] castto(uint64)) & ((1 << shift) - 1)) != 0
			lost = true
		;;
	;;

	bigshrfill(a, s, 0)
	if lost
		if a.sign == 0
			/* the whole magnitude was shifted out: the floor is -1 */
			a.dig = slpush(a.dig, 1)
			a.sign = -1
		else
			/* magnitude += 1, rippling through digits that wrap to zero */
			for i = 0; i < a.dig.len; i++
				a.dig[i] += 1
				if a.dig[i] != 0
					break
				;;
			;;
			/* every digit wrapped: the carry needs a new top digit */
			if i == a.dig.len
				a.dig = slpush(a.dig, 1)
			;;
		;;
	;;
	-> a
}

/*
 * Shifts the digits of 'a' right by 's' bits in place and returns 'a'.
 * Positions vacated at the most significant end are filled from 'fill':
 * whole digits are set to 'fill', and the bits entering the top digit
 * during the sub-digit shift come from 'fill' as well. Shifting by more
 * bits than 'a' holds discards every original digit.
 */
const bigshrfill = {a, s, fill
	var off, shift
	var t, carry
	var i, keep

	off = (s / 32) castto(size)
	shift = s % 32
	/* never move more digits than there are */
	if off > a.dig.len
		off = a.dig.len
	;;
	keep = a.dig.len - off

	/* move the surviving digits down, bottom up */
	for i = 0; i < keep; i++
		a.dig[i] = a.dig[i + off]
	;;
	/* the vacated digits at the top take the fill pattern */
	for i = keep; i < a.dig.len; i++
		a.dig[i] = fill
	;;
	/*
	 * and shift over by the remainder, most significant digit first.
	 * Only the low 32 bits of the carry survive the cast below, so a
	 * shift of 0 brings nothing in.
	 */
	carry = (fill castto(uint64)) << (32 - shift)
	for i = a.dig.len; i > 0; i--
		t = (a.dig[i - 1] castto(uint64))
		a.dig[i - 1] = (carry | (t >> shift)) castto(uint32)
		carry = t << (32 - shift)
	;;
	-> trim(a)
}