shithub: mc

ref: b9656ec7e120d1fced20ea5ca826c57f2de4fc84
dir: /lib/math/fma-impl.myr/

View raw version
use std

use "util"

/*
   Package-local fused multiply-add entry points. Both are intended
   to return x * y + z as if computed exactly and rounded only once
   to the result type. They are declared pkglocal, so they are not
   part of the package's public interface.
 */
pkg math =
	pkglocal const fma32 : (x : flt32, y : flt32, z : flt32 -> flt32)
	pkglocal const fma64 : (x : flt64, y : flt64, z : flt64 -> flt64)
;;

/*
   Bit masks selecting the biased-exponent field of a raw float
   encoding: 8 bits starting at bit 23 for the 32-bit layout, 11
   bits starting at bit 52 for the 64-bit layout. A value whose
   exponent field equals the mask is an infinity or a NaN; a value
   whose exponent field is zero is a zero or a subnormal.
 */
const exp_mask32 : uint32 = 0xff << 23
const exp_mask64 : uint64 = 0x7ff << 52

/*
   Fused multiply-add for flt32: returns x * y + z.

   The arithmetic is carried out in flt64 and the result is then
   narrowed to flt32. (24-bit by 24-bit significands need at most
   48 bits, so the flt64 product itself should be exact.) The
   addition of z may round once in flt64, and narrowing to flt32
   may round a second time; the bulk of this function detects the
   cases where those two roundings together give a different answer
   than a single rounding would, and nudges the flt64 significand
   by one unit so that the final narrowing rounds the right way.
 */
pkglocal const fma32 = {x : flt32, y : flt32, z : flt32
	var xn, yn
	/* Only the sign flags of x and y are needed from the flt32 forms */
	(xn, _, _) = std.flt32explode(x)
	(yn, _, _) = std.flt32explode(y)
	var xd : flt64 = flt64fromflt32(x)
	var yd : flt64 = flt64fromflt32(y)
	var zd : flt64 = flt64fromflt32(z)
	var prod : flt64 = xd * yd
	var pn, pe, ps
	(pn, pe, ps) = std.flt64explode(prod)
	/*
	   An exploded exponent of -1023 is treated as -1022, so that
	   the exponent reflects the scale of the significand's last bit.
	 */
	if pe == -1023
		pe = -1022
	;;
	if pn != (xn != yn)
		/* In case of NaNs, sign might not have been preserved */
		pn = (xn != yn)
		prod = std.flt64assem(pn, pe, ps)
	;;

	var r : flt64 = prod + zd
	var rn, re, rs
	(rn, re, rs) = std.flt64explode(r)

	/*
	   At this point, r is probably the correct answer. The
	   only issue is the rounding.

	   Ex 1: If x*y > 0 and z is a tiny, negative number, then
	   adding z probably does no rounding. However, if
	   truncating to 23 bits of precision would cause round-to-even,
	   and that round would be upwards, then we need to remember
	   those trailing bits of z and cancel the rounding.

	   Ex 2: If x, y, z > 0, and z is small, with
	                 last bit in flt64 |
	          last bit in flt32 v      v
	   x * y = ...............101011..11
	       z =                          10000...,
	   then x * y + z will be rounded to
	           ...............101100..00,
	   and then as a flt32 it will become
	           ...............110,
	   Even though, looking at the original bits, it doesn't
	   "deserve" the final rounding.

	   These can only happen if r is non-inf, non-NaN, and the
	   lower 29 bits correspond to "exactly halfway".
	 */
	/*
	   0x1fffffff selects the 29 low significand bits that flt32
	   cannot hold (52 - 23 = 29); 0x10000000 is the pattern
	   "1 followed by 28 zeros", i.e. exactly half of the last
	   flt32 place.
	 */
	if re == 1024 || rs & 0x1fffffff != 0x10000000
		-> flt32fromflt64(r)
	;;

	/*
	   At this point, a rounding is about to happen. We need
	   to know what direction that rounding is, so that we can
	   tell if it's wrong. +1 means "away from 0", -1 means
	   "towards 0".
	 */
	var zn, ze, zs
	(zn, ze, zs) = std.flt64explode(zd)
	var round_direction = 0
	/*
	   Bit 29 (0x20000000) is the lowest significand bit that
	   survives the narrowing. On an exact tie, round-to-even
	   truncates when that bit is 0 and rounds away when it is 1.
	 */
	if rs & 0x20000000 == 0
		round_direction = -1
	else
		round_direction = 1
	;;

	/* Order the two addends (product and z) by magnitude */
	var smaller, larger, smaller_e, larger_e
	if pe > ze || (pe == ze && ps > zs)
		(smaller, larger, smaller_e, larger_e) = (zs, ps, ze, pe)
	else
		(smaller, larger, smaller_e, larger_e) = (ps, zs, pe, ze)
	;;
	/*
	   mask selects the low (larger_e - smaller_e) bits of smaller,
	   capped at all 64: the bits that fall below larger's last place.
	 */
	var mask = shr((-1 : uint64), 64 - std.min(64, larger_e - smaller_e))
	var prevent_rounding = false
	if (round_direction > 0 && pn != zn) || (round_direction < 0 && pn == zn)
		/*
		   The prospective rounding disagrees with the
		   signage. We are potentially in the case of Ex
		   1.

		   Look at the bits (of the smaller flt64) that are
		   outside the range of r. If there are any such
		   bits, we need to cancel the rounding.

		   We certainly need to consider bits very far to
		   the right, but there's an awkwardness concerning
		   the bit just outside the flt64 range: it governed
		   round-to-even, so it might have had an effect.
		   We only care about bits which did not have an
		   effect. Therefore, we perform the subtraction
		   using only the bits from smaller that lie in
		   larger's range, then check whether the result
		   is susceptible to round-to-even.

		   (Since we only care about the last bit, and the
		   base is 2, subtraction or addition are equally
		   useful.)
		*/
		if (larger ^ shr(smaller, larger_e - smaller_e)) & 0x1 == 0
			prevent_rounding = smaller & mask != 0
		;;
	else
		/*
		   The prospective rounding agrees with the signage.
		   We are potentially in the case of Ex 2.

		   We just need to check if r was obtained by
		   rounding in the addition step. In this case, we
		   still check the smaller/larger, and we only
		   care about round-to-even. Any
		   rounding that happened previously is enough
		   reason to disqualify this next rounding.
		*/
		prevent_rounding = (larger ^ shr(smaller, larger_e - smaller_e)) & 0x1 != 0
	;;

	/*
	   Move the flt64 significand one unit off the exact tie, in
	   the direction opposite to the rounding that would otherwise
	   occur, so the narrowing below rounds the other way.
	 */
	if prevent_rounding
		if round_direction > 0
			rs--
		else
			rs++
		;;
	;;

	-> flt32fromflt64(std.flt64assem(rn, re, rs))
}

/*
   Fused multiply-add for flt64: returns x * y + z.

   There is no wider float type to lean on, so the work is done on
   integers. The significands of x and y are multiplied into a
   128-bit value [ prod_h ][ prod_l ], z's significand is aligned
   against it and added (same sign) or subtracted (opposite sign)
   into [ res_h ][ res_l ], and that 128-bit result is then rounded
   into a flt64, with separate handling for subnormal and overflowing
   results.

   Throughout, "lastbit_e" is the power of two represented by bit 0
   of an integer, "first1" is the index of its highest set bit, and
   "firstbit_e" is lastbit_e + first1.
 */
pkglocal const fma64 = {x : flt64, y : flt64, z : flt64
	var xn : bool, yn : bool, zn : bool
	var xe : int64, ye : int64, ze : int64
	var xs : uint64, ys : uint64, zs : uint64

	var xb : uint64 = std.flt64bits(x)
	var yb : uint64 = std.flt64bits(y)
	var zb : uint64 = std.flt64bits(z)

	/*
	   NOTE(review): the `x * y == 0.0` test below is also true when
	   the rounded product underflows to zero although the exact
	   product is non-zero (e.g. x = 2^-1074, y = 0.5). With a
	   subnormal z that looks like it can change the correctly
	   rounded result; confirm whether testing x and y for zero
	   individually would be safe for the code that follows.
	 */
	/* check for both NaNs and infinities */
	if xb & exp_mask64 == exp_mask64 || \
	   yb & exp_mask64 == exp_mask64
		-> x * y + z
	elif z == 0.0 || z == -0.0 || x * y == 0.0 || x * y == -0.0
		-> x * y + z
	elif zb & exp_mask64 == exp_mask64
		-> z
	;;

	(xn, xe, xs) = std.flt64explode(x)
	(yn, ye, ys) = std.flt64explode(y)
	(zn, ze, zs) = std.flt64explode(z)
	/*
	   An exploded exponent of -1023 is treated as -1022 so that
	   (e - 52) is the scale of the significand's last bit.
	 */
	if xe == -1023
		xe = -1022
	;;
	if ye == -1023
		ye = -1022
	;;
	if ze == -1023
		ze = -1022
	;;

	/* Keep product in high/low uint64s */
	var xs_h : uint64 = xs >> 32
	var ys_h : uint64 = ys >> 32
	var xs_l : uint64 = xs & 0xffffffff
	var ys_l : uint64 = ys & 0xffffffff

	/* Schoolbook multiplication on 32-bit halves */
	var t_l : uint64 = xs_l * ys_l
	var t_m : uint64 = xs_l * ys_h + xs_h * ys_l
	var t_h : uint64 = xs_h * ys_h

	var prod_l : uint64 = t_l + (t_m << 32)
	var prod_h : uint64 = t_h + (t_m >> 32)
	/* Carry out of the low word */
	if t_l > prod_l
		prod_h++
	;;

	var prod_n = xn != yn
	var prod_lastbit_e = (xe - 52) + (ye - 52)
	var prod_first1 = find_first1_64_hl(prod_h, prod_l, 105)
	var prod_firstbit_e = prod_lastbit_e + prod_first1

	var z_firstbit_e = ze
	var z_lastbit_e = ze - 52
	var z_first1 = 52

	/* subnormals could throw firstbit_e calculations out of whack */
	if (zb & exp_mask64 == 0)
		z_first1 = find_first1_64(zs, z_first1)
		z_firstbit_e = z_lastbit_e + z_first1
	;;

	var res_n
	var res_h = 0
	var res_l = 0
	var res_first1
	var res_lastbit_e
	var res_firstbit_e

	if prod_n == zn
		res_n = prod_n

		/*
		   Align prod and z so that the top bit of the
		   result is either 53 or 54, then add.
		 */
		if prod_firstbit_e >= z_firstbit_e
			/*
			    [ prod_h ][ prod_l ]
			         [ z...
			 */
			res_lastbit_e = prod_lastbit_e
			(res_h, res_l) = (prod_h, prod_l)
			(res_h, res_l) = add_shifted(res_h, res_l, zs, z_lastbit_e - prod_lastbit_e)
		else
			/*
			        [ prod_h ][ prod_l ]
			    [ z...
			 */
			res_lastbit_e = z_lastbit_e - 64
			res_h = zs
			res_l = 0
			if prod_lastbit_e >= res_lastbit_e + 64
				/* In this situation, prod must be extremely subnormal */
				res_h += shl(prod_l, prod_lastbit_e - res_lastbit_e - 64)
			elif prod_lastbit_e >= res_lastbit_e
				res_h += shl(prod_h, prod_lastbit_e - res_lastbit_e)
				res_h += shr(prod_l, res_lastbit_e + 64 - prod_lastbit_e)
				res_l += shl(prod_l, prod_lastbit_e - res_lastbit_e)
			elif prod_lastbit_e + 64 >= res_lastbit_e
				res_h += shr(prod_h, res_lastbit_e - prod_lastbit_e)
				var l1 = shl(prod_h, prod_lastbit_e + 64 - res_lastbit_e)
				var l2 = shr(prod_l, res_lastbit_e - prod_lastbit_e)
				res_l = l1 + l2
				if res_l < l1
					res_h++
				;;
			elif prod_lastbit_e + 128 >= res_lastbit_e
				res_l += shr(prod_h, res_lastbit_e - prod_lastbit_e - 64)
			;;
		;;
	else
		/* Opposite signs: subtract the smaller magnitude from the larger */
		match compare_hl_z(prod_h, prod_l, prod_firstbit_e, prod_lastbit_e, zs, z_firstbit_e, z_lastbit_e)
		| `std.Equal: -> 0.0
		| `std.Before:
			/* prod > z */
			res_n = prod_n
			res_lastbit_e = prod_lastbit_e
			(res_h, res_l) = sub_shifted(prod_h, prod_l, zs, z_lastbit_e - prod_lastbit_e)
		| `std.After:
			/* z > prod */
			res_n = zn
			res_lastbit_e = z_lastbit_e - 64
			(res_h, res_l) = sub_shifted(zs, 0, prod_h, prod_lastbit_e + 64 - (z_lastbit_e - 64))
			(res_h, res_l) = sub_shifted(res_h, res_l, prod_l, prod_lastbit_e - (z_lastbit_e - 64))
		;;
	;;

	/*
	   Locate the highest set bit of [ res_h ][ res_l ]. The value 63
	   (64 + -1) signals that nothing was found in res_h, in which
	   case res_l is searched instead.
	 */
	res_first1 = 64 + find_first1_64(res_h, 55)
	if res_first1 == 63
		res_first1 = find_first1_64(res_l, 63)
	;;
	res_firstbit_e = res_first1 + res_lastbit_e

	/*
	   Finally, res_h and res_l are the high and low bits of
	   the result. They now need to be assembled into a flt64.
	   Subnormals and infinities could be a problem.
	 */
	var res_s = 0
	if res_firstbit_e <= -1023
		/* Subnormal case */
		if res_lastbit_e + 128 < 12 - 1022
			res_s = shr(res_h, 12 - 1022 - (res_lastbit_e + 128))
			res_s |= shr(res_l, 12 - 1022 - (res_lastbit_e + 64))
		elif res_lastbit_e + 64 < 12 - 1022
			res_s = shl(res_h, -12 + (res_lastbit_e + 128) - (-1022))
			res_s |= shr(res_l, 12 - 1022 - (res_lastbit_e + 64))
		else
			res_s = shl(res_h, -12 + (res_lastbit_e + 128) - (-1022))
			res_s |= shl(res_l, -12 + (res_lastbit_e + 64) - (-1022))
		;;

		if need_round_away(res_h, res_l, res_first1 + (-1074 - res_firstbit_e))
			res_s++
		;;

		/* No need for exponents, they are all zero */
		var res = res_s
		if res_n
			res |= (1 << 63)
		;;
		-> std.flt64frombits(res)
	;;

	if res_firstbit_e >= 1024
		/* Infinity case */
		if res_n
			-> std.flt64frombits(0xfff0000000000000)
		else
			-> std.flt64frombits(0x7ff0000000000000)
		;;
	;;

	/*
	   Normal case: shift the 128-bit result so that its highest
	   set bit lands on bit 52, rounding on the bits shifted out.
	 */
	if res_first1 - 52 >= 64
		res_s = shr(res_h, (res_first1 : int64) - 64 - 52)
		if need_round_away(res_h, res_l, res_first1 - 52)
			res_s++
		;;
	elif res_first1 - 52 >= 0
		res_s = shl(res_h, 64 - (res_first1 - 52))
		res_s |= shr(res_l, res_first1 - 52)
		if need_round_away(res_h, res_l, res_first1 - 52)
			res_s++
		;;
	else
		res_s = shl(res_h, res_first1 - 52)
	;;

	/*
	   The res_s++s might have messed everything up: a rounding
	   carry out of the significand has to be renormalised by
	   shifting right one place and bumping the exponent. (This
	   used to read `res_s >= 1`, a comparison whose result was
	   discarded, so the shift never happened.)
	 */
	if res_s & (1 << 53) != 0
		res_s >>= 1
		res_firstbit_e++
		if res_firstbit_e >= 1024
			if res_n
				-> std.flt64frombits(0xfff0000000000000)
			else
				-> std.flt64frombits(0x7ff0000000000000)
			;;
		;;
	;;

	-> std.flt64assem(res_n, res_firstbit_e, res_s)
}

/*
   Returns the 128-bit sum [ h ][ l ] + (a << s) as a (high, low)
   pair. A negative s means a is shifted right by -s instead. With
   s == 0 the result is [ h ][ l + a ]. A carry out of the low word
   is propagated into the high word; a carry out of the high word
   is dropped.
 */
const add_shifted = {h : uint64, l : uint64, a : uint64, s : int64
	/* a lands entirely in the high word */
	if s >= 64
		-> (h + shl(a, s - 64), l)
	;;

	/* Split the shifted a into the part for each word */
	var hi_part : uint64 = 0
	var lo_part : uint64
	if s >= 0
		hi_part = shr(a, 64 - s)
		lo_part = shl(a, s)
	else
		lo_part = shr(a, -s)
	;;

	var hi = h + hi_part
	var lo = l + lo_part
	/* unsigned wrap-around of the low word means a carry */
	if lo < l
		hi++
	;;
	-> (hi, lo)
}

/*
   Returns the 128-bit difference [ h ][ l ] - (a << s) as a
   (high, low) pair. A negative s means a is shifted right by -s
   instead. With s == 0 the result is [ h ][ l - a ]. A borrow out
   of the low word is taken from the high word; the caller is
   expected to pass operands for which the difference is not
   negative.
 */
const sub_shifted = {h : uint64, l : uint64, a : uint64, s : int64
	/* a lines up entirely with the high word */
	if s >= 64
		-> (h - shl(a, s - 64), l)
	;;

	/* Split the shifted a into the part for each word */
	var hi_part : uint64 = 0
	var lo_part : uint64
	if s >= 0
		hi_part = shr(a, 64 - s)
		lo_part = shl(a, s)
	else
		lo_part = shr(a, -s)
	;;

	var hi = h - hi_part
	var lo = l - lo_part
	/* subtracting more than the low word holds needs a borrow */
	if lo_part > l
		hi--
	;;
	-> (hi, lo)
}

/*
   Compares the magnitude of the 128-bit integer [ h ][ l ] with
   that of the 64-bit integer z, where each comes with the power
   of two of its highest set bit (firstbit_e) and of its bit 0
   (lastbit_e).

   Returns `std.Before if hl is larger, `std.After if z is larger,
   and `std.Equal if they represent the same value.

   If the leading exponents differ, that decides it. Otherwise the
   two are walked bit by bit from their leading bits downwards,
   first through h, then through l. Whichever operand still has
   bits left when the other runs out is larger if any of those
   leftover bits is set.
 */
const compare_hl_z = {h : uint64, l : uint64, hl_firstbit_e : int64, hl_lastbit_e : int64, z : uint64, z_firstbit_e : int64, z_lastbit_e : int64
	if hl_firstbit_e > z_firstbit_e
		-> `std.Before
	elif hl_firstbit_e < z_firstbit_e
		-> `std.After
	;;

	/*
	   h_k and z_k index the next bit to compare in h and z. h_k is
	   negative from the start when the leading bit of hl lies in l.
	 */
	var h_k : int64 = (hl_firstbit_e - hl_lastbit_e - 64)
	var z_k : int64 = (z_firstbit_e - z_lastbit_e)
	while h_k >= 0 && z_k >= 0
		var h1 = h & shl(1, h_k) != 0
		var z1 = z & shl(1, z_k) != 0
		if h1 && !z1
			-> `std.Before
		elif !h1 && z1
			-> `std.After
		;;
		h_k--
		z_k--
	;;

	if z_k < 0
		/*
		   z is used up. Bits h_k..0 of h have not been looked at
		   yet (h_k + 1 of them, hence the shift by 63 - h_k), nor
		   has any of l.
		 */
		if (h & shr((-1 : uint64), 63 - h_k) != 0) || (l != 0)
			-> `std.Before
		else
			-> `std.Equal
		;;
	;;

	/*
	   z still has bits, so h_k < 0 here. If the loop above ran h
	   dry, h_k is -1 and we continue at bit 63 of l. If it never
	   ran, the leading bit of hl sits at bit 64 + h_k of l, and
	   that is where the comparison against z's leading bit starts.
	 */
	var l_k : int64 = 64 + h_k
	while l_k >= 0 && z_k >= 0
		var l1 = l & shl(1, l_k) != 0
		var z1 = z & shl(1, z_k) != 0
		if l1 && !z1
			-> `std.Before
		elif !l1 && z1
			-> `std.After
		;;
		l_k--
		z_k--
	;;

	/*
	   One side is used up; the other wins if any of its remaining
	   bits (index k down to 0, inclusive) is set.
	 */
	if (z_k < 0) && (l & shr((-1 : uint64), 63 - l_k) != 0)
		-> `std.Before
	elif (l_k < 0) && (z & shr((-1 : uint64), 63 - z_k) != 0)
		-> `std.After
	;;

	-> `std.Equal
}