shithub: mc

ref: 1f7f58d6a483853714d65099f32a916a92b18197
dir: /compile.myr/

View raw version
use std

use "types.use"
use "ranges.use"

/*
 Public interface of the regex package. compile and dbgcompile
 return either a compiled regex or the parse status that stopped
 compilation; free releases a regex returned by them.
*/
pkg regex =
	/* compiles a pattern */
	const compile	: (re : byte[:] -> std.error(regex#, status))
	/* compiles a pattern, dumping the AST and bytecode while doing so */
	const dbgcompile	: (re : byte[:] -> std.error(regex#, status))
	/* releases the bytecode and the regex itself */
	const free	: (re : regex# -> void)
;;

/* parse tree built by the parser and turned into bytecode by gen */
type tree = union
	/* basic string building */
	`Alt	[tree#, tree#]	/* left | right */
	`Cat	[tree#, tree#]	/* left followed by right */

	/* repetition */
	`Star	tree#	/* '*' */
	`Plus	tree#	/* '+' */
	`Quest	tree#	/* '?' */

	/* end matches */
	`Byte	byte
	`Chr	char	/* one codepoint, compiled to its utf8 bytes */
	`Class	[char, char]	/* inclusive codepoint range [lo, hi] */

	/* meta */
	`Cap	tree#	/* numbered capture group */
	`Bol	/* beginning of line */
	`Eol	/* end of line */
;;

/*
 Result of a parser step: `Some subtree on success, `None when nothing
 was parsed at this position (end of input, or a lower precedence
 operator such as '|' or ')'), or `Fail with the error status.
*/
type parseresult = union
	`Some tree#
	`None
	`Fail status
;;

/*
 Compiles pat into a freshly zero-allocated regex, with the debug
 flag left off.
*/
const compile = {pat
	var re

	re = std.zalloc()
	-> regexcompile(re, pat)
}

/*
 Like compile, but sets the regex's debug flag first, so the AST and
 the generated bytecode are printed while compiling. This can be
 verbose.
*/
const dbgcompile = {pat
	var dre

	dre = std.zalloc()
	dre.debug = true
	-> regexcompile(dre, pat)
}

/*
 Compiles pat into re, a zero-allocated regex from compile or
 dbgcompile (its debug flag may already be set).

 On success returns `std.Success re. re.prog then holds the bytecode
 for the whole pattern, wrapped in the capture markers for group 0 and
 terminated by `Imatch.

 On failure returns `std.Failure with the parse status. In that case
 re is freed here, so the caller must not touch it again. No bytecode
 has been appended at that point, so only the regex itself needs
 releasing.
*/
const regexcompile = {re, pat
	re.pat = pat
	re.nmatch = 1 /* whole match */
	match parse(re)
	| `None:
		std.free(re)
		-> `std.Failure (`Earlystop)
	| `Fail f:
		std.free(re)
		-> `std.Failure f
	| `Some t:
		/*
		parse() only returns `Some once the whole pattern has
		been consumed; this is a defensive check.
		*/
		if re.pat.len > 0
			astfree(t)
			std.free(re)
			-> `std.Failure (`Earlystop)
		;;
		dump(re, t, 0)
		append(re, `Ilbra 0)
		gen(re, t)
		append(re, `Irbra 0)
		append(re, `Imatch)
		idump(re)
		astfree(t)
		-> `std.Success re
	;;
	/* not reached: every arm above returns */
	-> `std.Failure (`Noimpl)
}

/*
 Releases a compiled regex: first its bytecode, then the regex itself.
 No matcher threads are released here; by the time this is called
 they should all be dead already.
*/
const free = {re
	var prog

	prog = re.prog
	std.slfree(prog)
	std.free(re)
}


/*
 Generates bytecode for the AST t, appending it to re.prog.
 Returns re.proglen after generation: the index at which the next
 instruction will be emitted.
*/
const gen = {re, t
	var m

	match t#
	|`Alt	(a, b): genalt(re, a, b)
	|`Cat	(a, b): gen(re, a); gen(re, b)
	/* repetition */
	|`Star	a:	genstar(re, a)
	/* a+ is compiled as a a* */
	|`Plus	a:	gen(re, a); genstar(re, a)
	|`Quest	a:	genquest(re, a)

	/* end matches */
	|`Byte	b: 	append(re, `Ibyte b)
	|`Chr	c:	genchar(re, c)
	|`Class  (a, b):	genrange(re, a, b)

	/* meta */
	|`Bol:	append(re, `Ibol)
	/* end-of-line anchors must emit `Ieol, not `Ibol */
	|`Eol:	append(re, `Ieol)
	|`Cap	a:
		/* claim the next capture group number */
		m = re.nmatch++
		append(re, `Ilbra m)
		gen(re, a)
		append(re, `Irbra m)
	;;
	-> re.proglen
}

/*
 converts a codepoint range spanning multiple utf8 byte lengths into a
 set of utf8 ranges, one alternative per encoded length, roughly:
 	[0x00-0x2000]  => [0x00-0x7F]|[0xC2-0xDF][0x80-0xBF]|...

 If both ends encode to a single byte, one `Irange is emitted.
 Otherwise a chain of `Ifork instructions selects a length class. Each
 class emits one `Irange per byte position, built from the utf8
 encodings of the class's lower and upper bound, and then jumps to a
 common end. Returns re.proglen after generation.

 NOTE(review): possible defects, verify before relying on multi-byte
 classes:
  - the top class's upper bound is charrng[hsz] - 1 == hi - 1, so hi
    itself appears to be excluded;
  - the emission loop starts at class 0 rather than lsz - 1, so when
    lo needs more than one byte, ranges below lo are emitted and there
    are hsz alternatives but only hsz - lsz forks;
  - byte-wise [lo byte, hi byte] ranges are only exact when the bounds
    line up with continuation-byte boundaries.
*/
const genrange = {re, lo, hi
	/* the transitions between different char lengths for unicode
	   characters, needed so that we know how to generate the
	   different size categories */
	var charrng = [
		0,
		0x80,
		0x800,
		0x10000,
		0x200000,
		-1
	]
	var lbuf : byte[4], hbuf : byte[4]
	var lsz, hsz
	var sz, end
	var d
	var i, j

	lsz = std.charlen(lo)
	hsz = std.charlen(hi)
	/* clamp the outermost length classes to the requested bounds */
	charrng[lsz - 1] = lo
	charrng[hsz] = hi
	if lsz == 1 && hsz == 1
		append(re, `Irange (lo castto(byte), hi castto(byte)))
	else
		/* one fork per extra length class, from the longest down */
		for i = hsz; i > lsz; i--
			if re.debug
				std.put("range size = %z\n", i - 2)
			;;
			d = re.proglen + i - 1
			append(re, `Ifork (re.proglen + 1, jmpdist(i) + d))
		;;
		/* presumably the index just past every alternative; TODO confirm */
		end = re.proglen + jmpdist(hsz + 1);
		/* one alternative per length class: per-byte ranges, then jump out */
		for i = 0; i < hsz; i++
			if re.debug
				std.put("lo[%z] = %i\n", i, charrng[i] castto(int))
				std.put("hi[%z] = %i\n", i, (charrng[i + 1] - 1) castto(int))
			;;
			sz = std.encode(lbuf[:], charrng[i])
			std.encode(hbuf[:], charrng[i + 1] - 1)
			for j = 0; j < sz; j++
				append(re, `Irange (lbuf[j], hbuf[j]))
			;;
			append(re, `Ijmp (end))
		;;
	;;
	-> re.proglen
}

/*
 Forward jump distance used by genrange for a utf8 range with n
 length classes: (n - 1) + (n - 1) + (n - 2) + ... + 1.
*/
const jmpdist = {n
	var dist
	var k

	dist = n - 1
	k = n - 1
	while k > 0
		dist += k
		k--
	;;
	-> dist
}

/*
 Generates an alternation l|r with the layout

 	fork L, R
 L:	<l>
 	jmp END
 R:	<r>
 END:

 The fork and jmp targets are unknown until both branches exist, so
 placeholders are emitted first and patched at the end.
*/
const genalt = {re, l, r
	var forkpc, jmppc
	var lstart, rstart, done

	forkpc = re.proglen
	lstart = append(re, `Ifork (-1, -1))
	gen(re, l)
	jmppc = re.proglen
	rstart = append(re, `Ijmp -1)
	done = gen(re, r)

	re.prog[jmppc] = `Ijmp done
	re.prog[forkpc] = `Ifork (lstart, rstart)
	-> re.proglen
}

/*
 Generates a zero-or-more repetition of rep with the layout

 TOP:	fork BODY, OUT
 BODY:	<rep>
 	jmp TOP
 OUT:
*/
const genstar = {re, rep
	var top, body, out
	var jmppc

	top = re.proglen
	body = append(re, `Ifork (-1, -1))
	jmppc = gen(re, rep)
	out = append(re, `Ijmp -1)

	re.prog[top] = `Ifork (body, out)
	re.prog[jmppc] = `Ijmp top
	-> re.proglen
}

/*
 Generates an optional match of q with the layout

 	fork BODY, OUT
 BODY:	<q>
 OUT:
*/
const genquest = {re, q
	var forkpc
	var body, out

	forkpc = re.proglen
	body = append(re, `Ifork (-1, -1))
	out = gen(re, q)
	re.prog[forkpc] = `Ifork (body, out)
	-> re.proglen
}

/* generates a match for codepoint c: one `Ibyte per utf8-encoded byte */
const genchar = {re, c
	var buf : byte[4]
	var nbytes
	var k

	nbytes = std.encode(buf[:], c)
	k = 0
	while k < nbytes
		append(re, `Ibyte buf[k])
		k++
	;;
	-> re.proglen
}

/*
 Appends insn to re.prog, doubling the backing storage when it is full.
 Returns the new program length, which is also the index the next
 appended instruction will get.
*/
const append = {re, insn
	var cap

	cap = re.prog.len
	if re.proglen == cap
		re.prog = std.slgrow(re.prog, std.max(1, 2*re.proglen))
	;;
	re.prog[re.proglen++] = insn
	-> re.proglen
}

/* prints re's bytecode, one numbered instruction per line, when re.debug is set */
const idump = {re
	var pc

	if re.debug
		for pc = 0; pc < re.proglen; pc++
			std.put("%i:\t", pc)
			match re.prog[pc]
			/* Char matching. Consume exactly one byte from the string. */
			| `Ibyte b:		std.put("`Ibyte %b (%c)\n", b, b castto(char))
			| `Irange (lo, hi):	std.put("`Irange (%b,%b)\n", lo, hi)
			/* capture groups */
			| `Ilbra grp:		std.put("`Ilbra %z\n", grp)
			| `Irbra grp:		std.put("`Irbra %z\n", grp)
			/* anchors */
			| `Ibol:			std.put("`Ibol\n")
			| `Ieol:			std.put("`Ieol\n")
			/* control flow */
			| `Ifork	(left, right):	std.put("`Ifork (%z,%z)\n", left, right)
			| `Ijmp dst:		std.put("`Ijmp %z\n", dst)
			| `Imatch:		std.put("`Imatch\n")
			;;
		;;
	;;
}

/* prints the AST t, two spaces per level of indent, when re.debug is set */
const dump = {re, t, indent
	var pad

	if !re.debug
		->
	;;
	pad = indent
	while pad > 0
		std.put("  ")
		pad--
	;;
	match t#
	/* binary nodes */
	| `Alt	(l, r):
		std.put("Alt\n")
		dump(re, l, indent + 1)
		dump(re, r, indent + 1)
	| `Cat	(l, r):
		std.put("Cat\n")
		dump(re, l, indent + 1)
		dump(re, r, indent + 1)
	/* unary nodes */
	| `Star	sub:
		std.put("Star\n")
		dump(re, sub, indent + 1)
	| `Plus	sub:
		std.put("Plus\n")
		dump(re, sub, indent + 1)
	| `Quest	sub:
		std.put("Quest\n")
		dump(re, sub, indent + 1)
	| `Cap	sub:
		std.put("Cap\n")
		dump(re, sub, indent + 1)
	/* leaves */
	| `Bol:
		std.put("Bol\n")
	| `Eol:
		std.put("Eol\n")
	| `Byte	b:
		std.put("Byte %b\n", b)
	| `Chr	c:
		std.put("Char %c\n", c)
	| `Class (lo, hi):
		std.put("Class (%c-%c)\n", lo, hi)
	;;
}

/*
 Parses the whole pattern in re.pat into an AST.

 Returns `Some tree only if the entire pattern was consumed. Leftover
 input (e.g. an unmatched ')') frees the tree and yields
 `Fail `Earlystop. Returns `None when nothing could be parsed at the
 start (e.g. an empty pattern). A `Fail from the parser is passed
 through unchanged.
*/
const parse = {re
	match altexpr(re)
	| `Some t:
		if re.pat.len == 0
			-> `Some t
		else
			astfree(t)
			-> `Fail (`Earlystop)
		;;
	| `None:
		-> `None
	| `Fail f:
		/* previously unhandled: parse errors must reach the caller */
		-> `Fail f
	;;
}

/*
 Parses an alternation: catexpr ('|' altexpr)?, grouped to the right.
 A trailing '|' with nothing after it is `Fail `Earlystop. On any
 failure the already-parsed left-hand side is freed.
*/
const altexpr = {re
	var ret : tree#

	match catexpr(re)
	| `Some t:
		ret = t
		if matchc(re, '|')
			match altexpr(re)
			| `Some rhs:
				ret = mk(`Alt (ret, rhs))
			| `None:
				astfree(ret)
				-> `Fail (`Earlystop)
			| `Fail f:
				/* don't leak the left-hand side on error */
				astfree(ret)
				-> `Fail f
			;;
		;;
	| other:
		-> other
	;;
	-> `Some ret
}

/*
 Parses a concatenation: one or more repexprs, grouped to the right.
 If a later element fails, the already-parsed prefix is freed before
 the failure is propagated.
*/
const catexpr = {re
	var ret

	match repexpr(re)
	| `Some t:
		ret = t
		match catexpr(re)
		| `Some rhs:
			ret = mk(`Cat (t, rhs))
		| `Fail f:
			/* free the prefix we already built */
			astfree(t)
			-> `Fail f
		| `None:	/* nothing */
		;;
	| other:
		-> other
	;;
	-> `Some ret
}

/*
 Parses a base expression optionally followed by one of the repetition
 operators '*', '+' or '?'. Non-`Some results from the base
 expression are returned unchanged.
*/
const repexpr = {re
	var base

	match baseexpr(re)
	| `Some t:
		base = t
	| other:
		-> other
	;;
	if matchc(re, '*')
		-> `Some mk(`Star base)
	elif matchc(re, '+')
		-> `Some mk(`Plus base)
	elif matchc(re, '?')
		-> `Some mk(`Quest base)
	;;
	-> `Some base
}

/*
 Parses one atom: a literal char, '.', an anchor, a bracketed class,
 an escape, or a parenthesized capture group.
 Returns `None at end of input or before '|' / ')'. Returns
 `Fail `Badrep for a repetition operator with nothing to repeat, and
 passes through failures from nested groups.
*/
const baseexpr = {re
	var ret

	if re.pat.len == 0
		-> `None
	;;
	match peekc(re)
	/* lower prec operators */
	| '|':	-> `None
	| ')':	-> `None
	| '*':	-> `Fail (`Badrep)
	| '+':	-> `Fail (`Badrep)
	| '?':	-> `Fail (`Badrep)
	| '[':	-> chrclass(re)
	| '.':	getc(re); ret = mk(`Class (0, std.Maxcharval))
	| '^':	getc(re); ret = mk(`Bol)
	| '$':	getc(re); ret = mk(`Eol)
	| '(':
		getc(re)
		match altexpr(re)
		| `Some s:	ret = mk(`Cap s)
		| `None:	-> `Fail (`Emptyparen)
		/* previously unhandled: ret would have been left unset */
		| `Fail f:	-> `Fail f
		;;
		if !matchc(re, ')')
			astfree(ret)
			-> `Fail (`Unbalanced)
		;;
	| '\\':
		getc(re) /* consume the slash */
		if re.pat.len == 0
			-> `Fail (`Earlystop)
		;;
		-> escaped(re)
	| c:
		getc(re)
		ret = mk(`Chr c)
	;;
	-> `Some ret
}

/*
 Parses the character after a backslash. It can be:
  - a predefined ascii class (\d \x \s \w \h) or its negation (\D \X
    \S \W \H);
  - a unicode class (\p, \P);
  - an escaped metacharacter, which then matches itself literally.
 Any other escape fails instead of leaving the result unset.
*/
const escaped = {re
	var ret

	match getc(re)
	/* character classes */
	| 'd': ret = `Some ranges(re, _ranges.tabasciidigit[:])
	| 'x': ret = `Some ranges(re, _ranges.tabasciixdigit[:])
	| 's': ret = `Some ranges(re, _ranges.tabasciispace[:])
	| 'w': ret = `Some ranges(re, _ranges.tabasciiword[:])
	| 'h': ret = `Some ranges(re, _ranges.tabasciiblank[:])

	/* negated character classes */
	| 'W': ret = `Some negranges(re, _ranges.tabasciiword[:])
	| 'S': ret = `Some negranges(re, _ranges.tabasciispace[:])
	| 'D': ret = `Some negranges(re, _ranges.tabasciidigit[:])
	| 'X': ret = `Some negranges(re, _ranges.tabasciixdigit[:])
	| 'H': ret = `Some negranges(re, _ranges.tabasciiblank[:])

	/* unicode character classes */
	| 'p':	ret = unicodeclass(re, false)
	| 'P':  ret = unicodeclass(re, true)

	/* escaped metachars */
	| '^': ret = `Some mk(`Chr '^')
	| '$': ret = `Some mk(`Chr '$')
	| '.': ret = `Some mk(`Chr '.')
	| '+': ret = `Some mk(`Chr '+')
	| '?': ret = `Some mk(`Chr '?')
	| '*': ret = `Some mk(`Chr '*')
	| '|': ret = `Some mk(`Chr '|')
	| '(': ret = `Some mk(`Chr '(')
	| ')': ret = `Some mk(`Chr ')')
	| '[': ret = `Some mk(`Chr '[')
	| ']': ret = `Some mk(`Chr ']')
	| '\\': ret = `Some mk(`Chr '\\')

	/*
	 unknown escape. NOTE(review): no dedicated "bad escape" status is
	 visible here, so Noimpl is used; consider adding one to types.use.
	*/
	| c: ret = `Fail (`Noimpl)
	;;
	-> ret
}

/*
 Parses the class name after \p or \P. The name is either a single
 char (\pL) or a braced name (\p{Letter}). Builds an AST for that
 unicode class, complemented when neg is true.

 Returns `Fail `Earlystop if the name is missing or the '{' is never
 closed, and `Fail `Badrange for an unknown class name.
*/
const unicodeclass = {re, neg
	var c, s
	var tab
	var n
	var closed

	if re.pat.len == 0
		-> `Fail (`Earlystop)
	;;
	n = 0
	/* either a single char pattern, or {pat} */
	if matchc(re, '{')
		/* the name starts after the '{', so take the slice only now */
		s = re.pat
		closed = false
		while re.pat.len > 0
			c = getc(re)
			if c == '}'
				closed = true
				break
			;;
			n += std.charlen(c)
		;;
		if !closed
			-> `Fail (`Earlystop)
		;;
	else
		s = re.pat
		n = std.charlen(getc(re))
	;;
	s = s[:n]
	/* letters */
	if std.sleq(s, "L") || std.sleq(s, "Letter")
		tab = _ranges.tabalpha[:]
	elif std.sleq(s, "Lu") || std.sleq(s, "Uppercase_Letter")
		tab = _ranges.tabupper[:]
	elif std.sleq(s, "Ll") || std.sleq(s, "Lowercase_Letter")
		tab = _ranges.tablower[:]
	elif std.sleq(s, "Lt") || std.sleq(s, "Titlecase_Letter")
		/* NOTE(review): titlecase currently reuses the lowercase table */
		tab = _ranges.tablower[:]
	/* numbers (incomplete) */
	elif std.sleq(s, "N") || std.sleq(s, "Number")
		tab = _ranges.tabdigit[:]
	elif std.sleq(s, "Z") || std.sleq(s, "Separator")
		tab = _ranges.tabspace[:]
	elif std.sleq(s, "Zs") || std.sleq(s, "Space_Separator")
		tab = _ranges.tabblank[:]
	else
		-> `Fail (`Badrange)
	;;
	if !neg
		-> `Some ranges(re, tab)
	else
		-> `Some negranges(re, tab)
	;;
}

/*
 Parses a bracketed class "[...]" or negated class "[^...]".
 The first member is read unconditionally, so a ']' right after the
 opening bracket is a literal. Negated classes are sorted and merged
 before being complemented. An unterminated class is
 `Fail `Earlystop.
*/
const chrclass = {re
	var rl, m
	var neg
	var t

	/* we know we saw '[' on entry */
	matchc(re, '[')
	neg = false
	if matchc(re, '^')
		neg = true
	;;
	/* "[" or "[^" at the end of the pattern */
	if re.pat.len == 0
		-> `Fail (`Earlystop)
	;;
	rl = rangematch(re, [][:])
	/* stop at end of input too, so an unterminated class can't spin forever */
	while re.pat.len > 0 && peekc(re) != ']'
		rl = rangematch(re, rl)
	;;
	if re.pat.len == 0 || !matchc(re, ']')
		std.slfree(rl)
		-> `Fail (`Earlystop)
	;;
	if neg
		/* sort and merge so the ranges are disjoint before complementing */
		std.sort(rl, {a, b;
			if a[0] < b[0]
				-> `std.Before
			elif a[0] == b[0]
				-> `std.Equal
			else
				-> `std.After
			;;})
		m = merge(rl)
		t = negranges(re, m)
		std.slfree(m)
	else
		t = ranges(re, rl)
	;;
	std.slfree(rl)
	-> `Some t
}

/*
 Parses one class member, a single char "c" or a range "a-b", and
 appends it to sl as an inclusive [lo, hi] pair. The ends are swapped
 if given backwards. A '-' just before the closing ']' (or at end of
 input) is a literal dash, not a range operator. Returns the grown
 slice.
*/
const rangematch = {re, sl
	var lo
	var hi

	lo = getc(re)
	if matchc(re, '-')
		if re.pat.len == 0 || peekc(re) == ']'
			/* "x-]": don't swallow the closing bracket as a range end */
			sl = std.slpush(sl, [lo, lo])
			-> std.slpush(sl, ['-', '-'])
		;;
		hi = getc(re)
		if lo <= hi
			-> std.slpush(sl, [lo, hi])
		else
			-> std.slpush(sl, [hi, lo])
		;;
	else
		-> std.slpush(sl, [lo, lo])
	;;
}

/*
 Builds an AST matching any of the inclusive [lo, hi] ranges in rng,
 as a balanced tree of `Alt nodes over `Class leaves.
 NOTE(review): rng must be non-empty; an empty slice recurses forever.
*/
const ranges = {re, rng
	var mid
	var lhs, rhs

	if rng.len == 1
		-> mk(`Class (rng[0][0], rng[0][1]))
	;;
	mid = rng.len / 2
	lhs = ranges(re, rng[:mid])
	rhs = ranges(re, rng[mid:])
	-> mk(`Alt (lhs, rhs))
}

/* builds an AST matching every codepoint not covered by the ranges in rng */
const negranges = {re, rng
	var inverse, node

	inverse = negate(rng)
	node = ranges(re, inverse)
	std.slfree(inverse)
	-> node
}

/*
 Returns the complement of rng within [0, std.Maxcharval], as a newly
 allocated slice of inclusive [lo, hi] ranges that the caller frees.

 rng is expected sorted by lo. chrclass sorts and merges first; the
 predefined tables are assumed sorted (TODO confirm).

 Empty gaps are skipped, so a range starting at 0 or ending at
 Maxcharval no longer yields an inverted [x, x-1] entry. An empty rng
 yields [0, Maxcharval].
 NOTE(review): if rng covers every codepoint the result is empty,
 which ranges() cannot handle.
*/
const negate = {rng
	var start
	var neg

	neg = [][:]
	start = 0
	for r in rng
		/* the gap between the previous range and this one, if any */
		if r[0] > start
			neg = std.slpush(neg, [start, r[0] - 1])
		;;
		/* don't move backwards if this range sits inside an earlier one */
		if r[1] + 1 > start
			start = r[1] + 1
		;;
	;;
	/* the tail above the last range, if any */
	if start <= std.Maxcharval
		neg = std.slpush(neg, [start, std.Maxcharval])
	;;
	-> neg
}

/*
 Merges a slice of inclusive [lo, hi] ranges, sorted by lo, into a new
 slice of disjoint, non-adjacent ranges. Overlapping or abutting ranges
 are combined. The caller owns (and frees) the result.
*/
const merge = {rl
	var lo, hi
	var ret

	if rl.len == 0
		-> [][:]
	;;
	ret = [][:]
	lo = rl[0][0]
	hi = rl[0][1]
	rl = rl[1:] /* BUG: compiler wants an rval in loop range */
	for r in rl
		/* if it overlaps or abuts, merge */
		if r[0] <= hi + 1
			/* r may lie entirely inside [lo, hi]; never shrink hi */
			hi = std.max(hi, r[1])
		else
			ret = std.slpush(ret, [lo, hi])
			lo = r[0]
			hi = r[1]
		;;
	;;
	-> std.slpush(ret, [lo, hi])
}


/* consumes c from the front of re.pat if it is the next char; returns whether it did */
const matchc = {re, c
	var next, rest

	(next, rest) = std.striter(re.pat)
	if next == c
		re.pat = rest
		-> true
	;;
	-> false
}

/* removes the next char from the front of re.pat and returns it */
const getc = {re
	var chr, rest

	(chr, rest) = std.striter(re.pat)
	re.pat = rest
	-> chr
}

/* returns the next char of re.pat without consuming it */
const peekc = {re
	var chr, rest

	(chr, rest) = std.striter(re.pat)
	-> chr
}

/* heap-allocates a tree node holding v; release it with astfree */
const mk = {v
	var node

	node = std.alloc()
	node# = v
	-> node
}

/* recursively frees the tree node t and all of its subtrees */
const astfree = {t
	match t#
	| `Alt	(a, b): astfree(a); astfree(b)
	| `Cat	(a, b): astfree(a); astfree(b)
	/* repetition */
	| `Star	a:	astfree(a)
	| `Plus	a:	astfree(a)
	| `Quest	a:	astfree(a)

	/* end matches: no children */
	| `Byte	b:
	| `Chr	c:
	| `Class (a, b):

	/* meta */
	| `Cap	a:	astfree(a)
	/* anchors have no children, but the match must still cover them */
	| `Bol:
	| `Eol:
	;;
	std.free(t)
}