shithub: mc

ref: 98235d9673f41003b14a0e8852785c6582fdff8f
dir: /lib/std/bitset.myr/

View raw version
use "alloc"
use "die"
use "extremum"
use "hashfuncs"
use "mk"
use "slcp"
use "sldup"
use "slfill"
use "traits"
use "types"

pkg std =
	/*
	 * A set of non-negative integers stored one bit per value.
	 * Value v lives in word v / (8*sizeof(size)), at bit v % (8*sizeof(size)).
	 * The word slice grows on demand when values are inserted.
	 */
	type bitset = struct
		bits : size[:]
	;;

	/* lifecycle: create, copy, release, and empty a set (bsclear returns its argument) */
	const mkbs	: (-> bitset#)
	const bsdup	: (bs : bitset# -> bitset#)
	const bsfree	: (bs : bitset# -> void)
	const bsclear	: (bs : bitset# -> bitset#)

	/*
	 * bsmax: number of bit positions currently backed by storage; every
	 *        value in the set is strictly below it.
	 * bscount: number of values in the set.
	 */
	const bsmax	: (a : bitset# -> size)
	const bscount	: (a : bitset# -> size)

	/*
	 * bsput: adds v; true if v was not already present.
	 * bsdel: removes v; true if v was present.
	 * bshas: true if v is present.
	 * v is converted to size before use.
	 */
	generic bsput	: (bs : bitset#, v : @a -> bool) :: integral,numeric @a
	generic bsdel	: (bs : bitset#, v : @a -> bool) :: integral,numeric @a
	generic bshas	: (bs : bitset#, v : @a -> bool) :: integral,numeric @a

	/*
	 * In-place set operations; `a` is the destination.
	 * bsdiff: clears in a every bit that is set in b.
	 * bsintersect: ANDs the words of a with the words of b.
	 * bsunion: ORs the words of b into a.
	 * bsissubset: true when every bit set in b is also set in a.
	 * bsunion and bsissubset pad both arguments to a common length,
	 * so they may grow the storage of either set.
	 */
	const bsdiff	: (a : bitset#, b : bitset# -> void)
	const bsintersect	: (a : bitset#, b : bitset# -> void)
	const bsunion	: (a : bitset#, b : bitset# -> void)
	const bsissubset	: (a : bitset#, b : bitset# -> bool)

	/*
	 * eq pads both operands to a common length before comparing words.
	 * hash ignores trailing all-zero words.
	 */
	impl equatable bitset#
	impl hashable bitset#

	/* iteration state: idx is the next bit position to examine in bs */
	type bsiter = struct
		idx	: size
		bs	: bitset#
	;;

	/* bybsvalue(bs) yields the values present in bs in increasing order */
	impl iterable bsiter -> size
	const  bybsvalue	: (bs : bitset# -> bsiter)
;;

/* Allocates a fresh bitset with zalloc and hands ownership to the caller. */
const mkbs = {
	var bs : bitset#

	bs = zalloc()
	-> bs
}

/*
 * Returns a newly allocated bitset whose word slice is a duplicate of
 * bs.bits; the copy shares no storage with the original.
 */
const bsdup = {bs
	var words

	words = sldup(bs.bits)
	-> mk([.bits = words])
}

/* Releases the word slice of bs and then the bitset structure itself. */
const bsfree = {bs
	var words = bs.bits

	slfree(words)
	free(bs)
}

/*
 * Removes every value from bs by zeroing all of its words. The storage is
 * kept at its current length. Returns bs so calls can be chained.
 */
const bsclear = {bs
	for var i = 0; i < bs.bits.len; i++
		bs.bits[i] = 0
	;;
	-> bs
}

/*
 * Returns the number of bit positions backed by the current storage:
 * (number of words) * (bits per word). Every value in the set is below it.
 */
const bsmax = {bs
	var wordbits : size

	wordbits = 8 * sizeof(size)
	-> wordbits * bs.bits.len
}

/* Returns how many values are in bs by walking the set with bybsvalue. */
const bscount = {bs
	var total

	total = 0
	for bit : bybsvalue(bs)
		total++
	;;
	-> total
}

/*
 * Adds v to bs, growing the word slice if v's word does not exist yet.
 * Returns true if v was newly added, false if it was already present.
 */
generic bsput = {bs, v
	var wordbits, word, mask, fresh

	wordbits = 8*sizeof(size)
	word = (v : size) / wordbits
	mask = 1 << ((v : size) % wordbits)
	/* make sure bs.bits[word] is a valid index before touching it */
	ensurelen(bs, word)

	fresh = (bs.bits[word] & mask) == 0
	bs.bits[word] |= mask
	-> fresh
}

/*
 * Removes v from bs. Returns true if v was present before the call.
 * A value whose word lies beyond the current storage cannot be present,
 * so nothing is touched and false is returned.
 */
generic bsdel = {bs, v
	var wordbits, word, mask, present

	wordbits = 8*sizeof(size)
	word = (v : size) / wordbits
	if word >= bs.bits.len
		-> false
	;;
	mask = 1 << ((v : size) % wordbits)
	present = (bs.bits[word] & mask) != 0
	bs.bits[word] &= ~mask
	-> present
}

/*
 * Returns true if v is in bs. Values whose word lies beyond the current
 * storage are reported as absent; the set is never grown by a lookup.
 */
generic bshas = {bs, v
	var wordbits, word, mask

	wordbits = 8*sizeof(size)
	word = (v : size) / wordbits
	if word < bs.bits.len
		mask = 1 << ((v : size) % wordbits)
		-> (bs.bits[word] & mask) != 0
	;;
	-> false
}

/*
 * Adds every value of b to a, in place. Both sets are first padded to a
 * common length with eqsz, so the storage of either one may grow.
 */
const bsunion = {a, b
	var i

	eqsz(a, b)
	i = 0
	while i < b.bits.len
		a.bits[i] = a.bits[i] | b.bits[i]
		i++
	;;
}

/*
 * Intersects a with b in place: afterwards a holds only the values that
 * were present in both sets. b is not modified, and neither set is grown.
 */
const bsintersect = {a, b
	var n

	/* words that exist in both sets are combined bit by bit */
	n = min(a.bits.len, b.bits.len)
	for var i = 0; i < n; i++
		a.bits[i] &= b.bits[i]
	;;
	/*
	 * Words of a past the end of b have no counterpart in b, so none of
	 * their values can be in the intersection. Clear them; leaving them
	 * alone would keep values in a that b does not contain.
	 */
	for var j = n; j < a.bits.len; j++
		a.bits[j] = 0
	;;
}

/*
 * Removes from a every value that is also in b, in place. Only the words
 * both sets have are visited: words of a beyond the end of b cannot
 * contain values of b and are left as they are.
 */
const bsdiff = {a, b
	var shared, i

	shared = min(b.bits.len, a.bits.len)
	i = 0
	while i < shared
		a.bits[i] = a.bits[i] & ~b.bits[i]
		i++
	;;
}

/*
 * Returns true when every value in b is also in a, i.e. b is a subset of a.
 * Both sets are first padded to a common length with eqsz, so the storage
 * of either one may grow.
 */
const bsissubset = {a, b
	eqsz(a, b)
	for var i = 0; i < a.bits.len; i++
		/* a bit set in b but clear in a means b is not contained in a */
		if (b.bits[i] & ~a.bits[i]) != 0
			-> false
		;;
	;;
	-> true
}

/*
 * Two bitsets are equal when they contain the same values. Both operands
 * are padded to a common length with eqsz first, so their storage may grow.
 */
impl equatable bitset# =
	eq = {a, b
		var i

		eqsz(a, b)
		/* advance past matching words; stop at the first difference */
		i = 0
		while i < a.bits.len && a.bits[i] == b.bits[i]
			i++
		;;
		-> i == a.bits.len
	}
;;

/*
 * Hashes the words of the set up to and including the last non-zero word.
 * Trailing zero words are skipped so that the amount of padding in the
 * storage does not influence the hash.
 */
impl hashable bitset# =
	hash = {a
		var n : size

		n = a.bits.len
		while n > 0
			if a.bits[n - 1] != 0
				break
			;;
			n--
		;;
		-> std.hash(a.bits[:n])
	}
;;

/*
 * Makes sure bs.bits[len] is a valid index. Despite its name, `len` is
 * used as an index: when it is not below the current length, the slice
 * is grown with slzgrow to len + 1 words.
 */
const ensurelen = {bs, len
	if len >= bs.bits.len
		slzgrow(&bs.bits, len + 1)
	;;
}

/*
 * Pads the shorter of a and b with slzgrow until both word slices have the
 * same length. A set that is already long enough is left untouched.
 *
 * This deliberately does not go through ensurelen: ensurelen takes an
 * index and grows to index + 1, so passing it a length would add one
 * extra word to both sets on every call, even when they already match.
 */
const eqsz = {a, b
	var sz

	sz = max(a.bits.len, b.bits.len)
	if a.bits.len < sz
		slzgrow(&a.bits, sz)
	;;
	if b.bits.len < sz
		slzgrow(&b.bits, sz)
	;;
}

/*
 * Builds an iterator over the values in bs, starting at bit position 0.
 * The iterator keeps a pointer to bs rather than a copy of it.
 */
const bybsvalue = {bs
	var it : bsiter

	it.idx = 0
	it.bs = bs
	-> it
}

/*
 * Iteration over the values of a bitset in increasing order.
 * itp.idx holds the next bit position to examine.
 */
impl iterable bsiter -> size =
	/*
	 * Finds the next set bit at or after itp.idx. On success, stores it
	 * through valp, advances itp.idx past it and returns true. Returns
	 * false when no set bit remains.
	 */
	__iternext__ = {itp, valp
		var bs = itp.bs
		/* one past the highest bit position backed by storage */
		var n = bsmax(bs)
		for var i = itp.idx; i < n; i++
			/* fast forward by whole chunks */
			/*
			 * While the word containing i is all zero, jump i to the
			 * first bit of the next word. This can leave i >= n.
			 */
			while i < n && bs.bits[i / (8*sizeof(size))] == 0
				i = (i + 8*sizeof(size)) & ~(8*sizeof(size) - 1)
			;;
			/* bshas returns false for out-of-range positions, so i >= n is safe here */
			if bshas(bs, i)
				itp.idx = i + 1
				valp# = i
				-> true
			;;
		;;
		/* exhausted: park the cursor at the end of the storage */
		itp.idx = n
		-> false
	}

	/* nothing to release per iteration */
	__iterfin__ = {itp, valp
	}
;;