/*
    shithub: mc

    ref: 4d79fc86e2b3b90faea794feb00fb6841c0957e7
    dir: /lib/math/exp-impl.myr/

    View raw version
 */
use std

use "fpmath"
use "util"

/*
    See [Mul16] (6.2.1), [Tan89], and [Tan92]. While the exp and
    expm1 functions are similar enough to be in the same file (Tang's
    algorithms use the same S constants), they are not quite similar
    enough to be in the same function.
 */
pkg math =
	/* exp(f) for single and double precision; both forward to expgen */
	pkglocal const exp32 : (f : flt32 -> flt32)
	pkglocal const exp64 : (f : flt64 -> flt64)

	/* expm1(f) for single and double precision; both forward to expm1gen */
	pkglocal const expm132 : (f : flt32 -> flt32)
	pkglocal const expm164 : (f : flt64 -> flt64)
;;

/*
   Helpers defined outside this file. They are stored in the fltdesc
   descriptors below (fields .fma and .horner). The horner_polyu*
   functions take their coefficient slice as unsigned integers --
   presumably raw floating-point bit patterns; confirm against their
   definition.
 */
extern const fma32 : (x : flt32, y : flt32, z : flt32 -> flt32)
extern const fma64 : (x : flt64, y : flt64, z : flt64 -> flt64)
extern const horner_polyu32 : (f : flt32, a : uint32[:] -> flt32)
extern const horner_polyu64 : (f : flt64, a : uint64[:] -> flt64)

/*
   Per-precision descriptor consumed by expgen, expm1gen and round.
   @f is the floating-point type, @u the unsigned integer type used
   for bit patterns and @i the signed integer type used for exponents.
   Floating-point constants are stored as @u bit patterns and turned
   into @f with frombits at the point of use.
 */
type fltdesc(@f, @u, @i) = struct
	/* f -> (n, e, s); the code in this file treats n as "f is negative" */
	explode : (f : @f -> (bool, @i, @u))
	/* rebuilds a float from the triple produced by explode */
	assem : (n : bool, e : @i, s : @u -> @f)
	/* not referenced by the code visible in this file */
	fma : (x : @f, y : @f, z : @f -> @f)
	/* polynomial evaluator, called as d.horner(x, coefficients) */
	horner : (f : @f, a : @u[:] -> @f)
	/* mask selecting the sign bit; ~sgnmask yields the bits of |f| */
	sgnmask : @u
	tobits : (f : @f -> @u)
	frombits : (u : @u -> @f)
	/* bit patterns returned for +infinity and NaN results */
	inf : @u
	nan : @u

	/* For exp */
	/* negative f with bits above thresh_1_min: expgen returns 0 */
	thresh_1_min : @u
	/* non-negative f with bits at or above thresh_1_max: result is +inf */
	thresh_1_max : @u
	/* |f| bits at or below thresh_2: expgen returns exactly 1 */
	thresh_2 : @u
	/* polynomial coefficients; also read by Procedure 1 of expm1gen */
	Ai : @u[:]

	/* For expm1 */
	/* |f| bits at or below thresh_tiny take the "tiny" special case */
	thresh_tiny : @u
	/* negative f with bits at or above thresh_huge_neg: result is -1 */
	thresh_huge_neg : @u
	/* bounds of the interval handled by Procedure 2 of expm1gen */
	Log3_4 : @u
	Log5_4 : @u
	/* polynomial coefficients for Procedure 2 */
	Bi : @u[:]
	/* compared against the scaling exponent M in expm1gen */
	precision : @u

	/* For both */
	/* bit count: shift amount in the |N| test and in round's mask */
	nabs : @u
	/* reduction constants: 1/L, and L split as L1 + L2 */
	inv_L : @u
	L1 : @u
	L2 : @u
	/* (high, low) bit-pattern pairs, indexed by J = N mod 32 */
	S : (@u, @u)[32]

;;

/*
   Single-precision descriptor for expgen / expm1gen. All floating-point
   constants are stored as IEEE-754 binary32 bit patterns.
 */
const desc32 : fltdesc(flt32, uint32, int32) = [
	.explode = std.flt32explode,
	.assem = std.flt32assem,
	.fma = fma32,
	.horner = horner_polyu32,
	.sgnmask = (1 << 31),
	.tobits = std.flt32bits,
	.frombits = std.flt32frombits,
	.inf = 0x7f800000,
	.nan = 0x7fc00000,
	.thresh_1_min = 0xc2cff1b4, /* ln(2^(-127 - 23)) ~= -103.9720770839917 */
	.thresh_1_max = 0x42b17218, /* ln([2 - 2^(-24)] * 2^127) ~= 88.72283905206 */
	/*
	   expgen reads thresh_2 for its "exp(f) is close to 1" shortcut, so
	   it must be set explicitly rather than left to an implicit value.
	   2^(-25) is Tang's THRESHOLD_2 for single precision: for
	   |f| <= 2^(-25), exp(f) rounds to exactly 1.
	 */
	.thresh_2 = 0x33000000, /* 2^(-25) */
	.inv_L = 0x4238aa3b, /* L = log(2) / 32, this is 1/L in working precision */
	.L1 = 0x3cb17200, /* L1 and L2 sum to give L in high precision, */
	.L2 = 0x333fbe8e, /* and L1 has some trailing zeroes. */
	.nabs = 9, /* L1 has 9 trailing zeroes in binary */
	.Ai = [ /* Coefficients for approximating expm1 in a tiny range */
		0x3f000044,
		0x3e2aaaec,
	][:],
	.Log3_4 = 0xbe934b11, /* log(1 - 1/4) < x < log(1 + 1/4) */
	.Log5_4 = 0x3e647fbf, /* triggers special expm1 case */
	.thresh_tiny = 0x33000000, /* similar to thresh_1_{min,max}, but for expm1 */
	.thresh_huge_neg = 0xc18aa122,
	.Bi = [ /* Coefficients for approximating expm1 between log(3/4) and log(5/4) */
		0x3e2aaaaa,
		0x3d2aaaa0,
		0x3c0889ff,
		0x3ab64de5,
		0x394ab327,
	][:],
	.S = [ /* Double-precise expansions of 2^(J/32) */
		(0x3f800000, 0x00000000),
		(0x3f82cd80, 0x35531585),
		(0x3f85aac0, 0x34d9f312),
		(0x3f889800, 0x35e8092e),
		(0x3f8b95c0, 0x3471f546),
		(0x3f8ea400, 0x36e62d17),
		(0x3f91c3c0, 0x361b9d59),
		(0x3f94f4c0, 0x36bea3fc),
		(0x3f9837c0, 0x36c14637),
		(0x3f9b8d00, 0x36e6e755),
		(0x3f9ef500, 0x36c98247),
		(0x3fa27040, 0x34c0c312),
		(0x3fa5fec0, 0x36354d8b),
		(0x3fa9a140, 0x3655a754),
		(0x3fad5800, 0x36fba90b),
		(0x3fb123c0, 0x36d6074b),
		(0x3fb504c0, 0x36cccfe7),
		(0x3fb8fb80, 0x36bd1d8c),
		(0x3fbd0880, 0x368e7d60),
		(0x3fc12c40, 0x35cca667),
		(0x3fc56700, 0x36a84554),
		(0x3fc9b980, 0x36f619b9),
		(0x3fce2480, 0x35c151f8),
		(0x3fd2a800, 0x366c8f89),
		(0x3fd744c0, 0x36f32b5a),
		(0x3fdbfb80, 0x36de5f6c),
		(0x3fe0ccc0, 0x36776155),
		(0x3fe5b900, 0x355cef90),
		(0x3feac0c0, 0x355cfba5),
		(0x3fefe480, 0x36e66f73),
		(0x3ff52540, 0x36f45492),
		(0x3ffa8380, 0x36cb6dc9),
	],
	.precision = 24, /* threshold to prevent underflow in a S[high] + 2^-m */
]

/*
   Double-precision descriptor for expgen / expm1gen. All floating-point
   constants are stored as IEEE-754 binary64 bit patterns.
 */
const desc64 : fltdesc(flt64, uint64, int64) = [
	.explode = std.flt64explode,
	.assem = std.flt64assem,
	.fma = fma64,
	.horner = horner_polyu64,
	.sgnmask = (1 << 63),
	.tobits = std.flt64bits,
	.frombits = std.flt64frombits,
	.inf = 0x7ff0000000000000,
	.nan = 0x7ff8000000000000,
	.thresh_1_min = 0xc0874910d52d3052, /* ln(2^(-1023 - 52)) ~= -745.1332191019 */
	.thresh_1_max = 0x40862e42fefa39ef, /* ln([2 - 2^(-53)] * 2^1023) ~= 709.78271289 */
	/*
	   expgen reads thresh_2 for its "exp(f) is close to 1" shortcut, so
	   it must be set explicitly rather than left to an implicit value.
	   2^(-54) is Tang's THRESHOLD_2 for double precision: for
	   |f| <= 2^(-54), exp(f) rounds to exactly 1.
	 */
	.thresh_2 = 0x3c90000000000000, /* 2^(-54) */
	.inv_L = 0x40471547652b82fe, /* below as in exp32 */
	.L1 = 0x3f962e42fef00000,
	.L2 = 0x3d8473de6af278ed,
	.nabs = 20,
	.Ai = [
		0x3fe0000000000000,
		0x3fc5555555548f7c,
		0x3fa5555555545d4e,
		0x3f811115b7aa905e,
		0x3f56c1728d739765,
	][:],
	.Log3_4 = 0xbfd269621134db93,
	.Log5_4 = 0x3fcc8ff7c79a9a22,
	.thresh_tiny = 0x3c90000000000000,
	.thresh_huge_neg = 0xc042b708872320e1,
	.Bi = [
		0x3fc5555555555549,
		0x3fa55555555554b6,
		0x3f8111111111a9f3,
		0x3f56c16c16ce14c6,
		0x3f2a01a01159dd2d,
		0x3efa019f635825c4,
		0x3ec71e14bfe3db59,
		0x3e928295484734ea,
		0x3e5a2836aa646b96,
	][:],
	.S = [
		(0x3ff0000000000000, 0x0000000000000000),
		(0x3ff059b0d3158540, 0x3d0a1d73e2a475b4),
		(0x3ff0b5586cf98900, 0x3ceec5317256e308),
		(0x3ff11301d0125b40, 0x3cf0a4ebbf1aed93),
		(0x3ff172b83c7d5140, 0x3d0d6e6fbe462876),
		(0x3ff1d4873168b980, 0x3d053c02dc0144c8),
		(0x3ff2387a6e756200, 0x3d0c3360fd6d8e0b),
		(0x3ff29e9df51fdec0, 0x3d009612e8afad12),
		(0x3ff306fe0a31b700, 0x3cf52de8d5a46306),
		(0x3ff371a7373aa9c0, 0x3ce54e28aa05e8a9),
		(0x3ff3dea64c123400, 0x3d011ada0911f09f),
		(0x3ff44e0860618900, 0x3d068189b7a04ef8),
		(0x3ff4bfdad5362a00, 0x3d038ea1cbd7f621),
		(0x3ff5342b569d4f80, 0x3cbdf0a83c49d86a),
		(0x3ff5ab07dd485400, 0x3d04ac64980a8c8f),
		(0x3ff6247eb03a5580, 0x3cd2c7c3e81bf4b7),
		(0x3ff6a09e667f3bc0, 0x3ce921165f626cdd),
		(0x3ff71f75e8ec5f40, 0x3d09ee91b8797785),
		(0x3ff7a11473eb0180, 0x3cdb5f54408fdb37),
		(0x3ff82589994cce00, 0x3cf28acf88afab35),
		(0x3ff8ace5422aa0c0, 0x3cfb5ba7c55a192d),
		(0x3ff93737b0cdc5c0, 0x3d027a280e1f92a0),
		(0x3ff9c49182a3f080, 0x3cf01c7c46b071f3),
		(0x3ffa5503b23e2540, 0x3cfc8b424491caf8),
		(0x3ffae89f995ad380, 0x3d06af439a68bb99),
		(0x3ffb7f76f2fb5e40, 0x3cdbaa9ec206ad4f),
		(0x3ffc199bdd855280, 0x3cfc2220cb12a092),
		(0x3ffcb720dcef9040, 0x3d048a81e5e8f4a5),
		(0x3ffd5818dcfba480, 0x3cdc976816bad9b8),
		(0x3ffdfc97337b9b40, 0x3cfeb968cac39ed3),
		(0x3ffea4afa2a490c0, 0x3cf9858f73a18f5e),
		(0x3fff50765b6e4540, 0x3c99d3e12dd8a18b),
	],
	.precision = 53,
]

/* Single-precision exp: runs the generic routine with the flt32 descriptor. */
const exp32 = {f : flt32
	var result = expgen(f, desc32)
	-> result
}

/* Double-precision exp: runs the generic routine with the flt64 descriptor. */
const exp64 = {f : flt64
	var result = expgen(f, desc64)
	-> result
}

/*
   Generic exp(f), shared by exp32 and exp64 through the descriptor d.

   Order of work:
     1. NaN input yields NaN.
     2. Inputs beyond d.thresh_1_max / d.thresh_1_min yield +inf / 0.
     3. Inputs with |f| at or below d.thresh_2 yield exactly 1.
     4. Otherwise f is reduced as f = N*L + (R1 + R2) with N = 32*M + J,
        a polynomial in the reduced argument is evaluated, and the
        result is rebuilt from the table entry d.S[J] with M added to
        the exponent.
 */
generic expgen = {f : @f, d : fltdesc(@f, @u, @i) :: numeric,floating,std.equatable @f, numeric,integral @u, numeric,integral @i, roundable @f -> @i
	var b = d.tobits(f)
	var n, e, s
	(n, e, s) = d.explode(f)

	/*
	   NaN must propagate. Without this check the raw bit pattern
	   of a NaN compares above the thresholds below, so a positive
	   NaN would come back as +inf and a negative NaN as 0.
	 */
	if std.isnan(f)
		-> d.frombits(d.nan)
	;;

	/*
	   Detect if exp(f) would round to outside representability.
	   We don't do bias adjustment, so Tang's thresholds can
	   be tightened.
	 */
	if !n && b >= d.thresh_1_max
		-> d.frombits(d.inf)
	elif n && b > d.thresh_1_min
		-> (0.0 : @f)
	;;

	/* Detect if exp(f) is close to 1. */
	if (b & ~d.sgnmask) <= d.thresh_2
		-> (1.0 : @f)
	;;

	/* Argument reduction to [ -ln(2)/64, ln(2)/64 ] */
	var inv_L = d.frombits(d.inv_L)

	var N = rn(f * inv_L)
	/* N2 is N mod 32 forced into 0..31; N1 is the remaining multiple of 32 */
	var N2  = N % (32 : @i)
	if N2 < 0
		N2 += (32 : @i)
	;;
	var N1 = N - N2

	var R1 : @f, R2 : @f

	/*
	   There are few enough significant bits that these are all
	   exact, without need for an FMA. R1 + R2 will approximate
	   (very well) f reduced into [ -ln(2)/64, ln(2)/64 ]
	 */
	if std.abs(N) >= (1 << d.nabs)
		R1 = (f - (N1 : @f) * d.frombits(d.L1)) - ((N2 : @f) * d.frombits(d.L1))
	else
		R1 = f - (N : @f) * d.frombits(d.L1)
	;;
	R2 = -1.0 * (N : @f) * d.frombits(d.L2)

	/* M scales the final exponent; J selects the table entry */
	var M = (N1 : @i) / 32
	var J = (N2 : std.size)

	/* Compute a polynomial approximation of expm1 */
	var R = R1 + R2
	var Q = R * R * d.horner(R, d.Ai)
	var P = R1 + (R2 + Q)

	/* Expand out from reduced range */
	var Su_hi, Su_lo
	(Su_hi, Su_lo) = d.S[J]
	var S_hi = d.frombits(Su_hi)
	var S_lo = d.frombits(Su_lo)

	var S = S_hi + S_lo
	var unscaled = S_hi + (S_lo + (P * S))
	/* Multiply by 2^M by adding M to the exponent field */
	var nu, eu, su
	(nu, eu, su) = d.explode(unscaled)
	var exp = d.assem(nu, eu + M, su)
	if (d.tobits(exp) == 0)
		/* Make sure we don't quite return 0 */
		-> d.frombits(1)
	;;

	-> exp
}

/* Single-precision expm1: runs the generic routine with the flt32 descriptor. */
const expm132 = {f : flt32
	var result = expm1gen(f, desc32)
	-> result
}

/* Double-precision expm1: runs the generic routine with the flt64 descriptor. */
const expm164 = {f : flt64
	var result = expm1gen(f, desc64)
	-> result
}

/*
   Generic expm1(f) = exp(f) - 1, shared by expm132 and expm164 through
   the descriptor d.

   Order of work:
     1. Special cases: +/-0 and +inf return f, -inf returns -1, NaN
        returns d.nan, tiny |f| gets a dedicated formula, and inputs
        beyond d.thresh_1_max / d.thresh_huge_neg return +inf / -1.
     2. "Procedure 2" for f between the values encoded by d.Log3_4 and
        d.Log5_4: a direct polynomial in f using d.Bi.
     3. "Procedure 1" otherwise: the same reduction and table lookup
        as expgen, with the final "- 1" folded in according to the
        size of the scaling exponent M.
 */
generic expm1gen = {f : @f, d : fltdesc(@f, @u, @i) :: \
    numeric,floating,std.equatable @f,
    numeric,integral @u,
    numeric,integral @i,
    roundable @f -> @i
	var b = d.tobits(f)
	var n, e, s
	(n, e, s) = d.explode(f)

	/* Special cases: +/- 0, inf, NaN, tiny, and huge */
	if (b & ~d.sgnmask == 0)
		-> f
	elif n && (b & ~d.sgnmask == d.inf)
		-> (-1.0 : @f)
	elif (b & ~d.sgnmask == d.inf)
		-> f
	elif std.isnan(f)
		-> d.frombits(d.nan)
	elif (b & ~d.sgnmask) <= d.thresh_tiny
		/*
		   Computes (large * f + |f|) * small.
		   NOTE(review): presumably assem(false, +/-100, 0) yields
		   2^(+/-100) -- confirm against the assem implementations.
		 */
		var two_to_large = d.assem(false, 100, 0)
		var two_to_small = d.assem(false, -100, 0)
		var abs_f = d.assem(false, e, s)
		-> (two_to_large * f + abs_f) * two_to_small
	elif !n && b >= d.thresh_1_max /* exp(x) = oo <=> expm1(x) = oo, as it turns out */
		-> d.frombits(d.inf)
	elif n && b >= d.thresh_huge_neg
		-> (-1.0 : @f)
	;;

	/*
	   Bit patterns are compared within one sign at a time, so
	   "b < bound" here means |f| is smaller than |bound|.
	 */
	if ((n && b < d.Log3_4) || (!n && b < d.Log5_4))
		/* Procedure 2 */

		/* compute x^2  / 2 with extra precision */
		/* u is f passed through round(); v is the remainder f - u */
		var u = round(f, d)
		var v = f - u
		var y = u * u * (0.5 : @f)
		var z = v * (f + u) * (0.5 : @f)
		var q = f * f * f * d.horner(f, d.Bi)

		/* The summation order depends on the exponent of y */
		var yn, ye, ys
		(yn, ye, ys) = d.explode(y)
		if (ye >= -7)
			-> (u + y) + (q + (v  + z))
		else
			-> f + (y + (q + z))
		;;
	;;

	/* Procedure 1 */
	var inv_L = d.frombits(d.inv_L)

	var N = rn(f * inv_L)
	/* N2 is N mod 32 forced into 0..31; N1 is the remaining multiple of 32 */
	var N2 = N % (32 : @i)
	if N2 < 0
		N2 += (32 : @i)
	;;
	var N1 = N - N2

	var R1 : @f, R2 : @f

	/*
	   As in the exp case, R1 + R2 very well approximates f
	   reduced into [ -ln(2)/64, ln(2)/64 ]
	 */
	if std.abs(N) >= (1 << d.nabs)
		R1 = (f - (N1 : @f) * d.frombits(d.L1)) - ((N2 : @f) * d.frombits(d.L1))
	else
		R1 = f - (N : @f) * d.frombits(d.L1)
	;;
	R2 = -1.0 * (N : @f) * d.frombits(d.L2)

	/* M is the power-of-two scale; J selects the table entry */
	var M = (N1 : @i) / 32
	var J = (N2 : std.size)

	/* Compute a polynomial approximation of expm1 */
	var R = R1 + R2
	var Q = R * R * d.horner(R, d.Ai)
	var P = R1 + (R2 + Q)

	/* Expand out */
	var Su_hi, Su_lo
	(Su_hi, Su_lo) = d.S[J]
	var S_hi = d.frombits(Su_hi)
	var S_lo = d.frombits(Su_lo)
	var S = S_hi + S_lo

	/*
	   Note that [Tan92] includes cases to avoid tripping the
	   underflow flag when not technically required. We currently
	   do not care about such flags, so that logic is omitted.
	 */
	/*
	   Three ways of folding in the "- 1", chosen by M:
	     M >= d.precision : 2^-M is added among the low-order terms
	     M <= -8          : 1 is subtracted after scaling
	     otherwise        : 2^-M is subtracted from S_hi before scaling
	 */
	if M >= d.precision
		-> scale2(S_hi + (S * P + (S_lo - scale2(1.0, -M))), M)
	elif M <= -8
		-> scale2(S_hi + (S * P + S_lo), M) - (1.0 : @f)
	;;

	-> scale2((S_hi - scale2(1.0, -M)) + (S_hi * P + S_lo * (P + (1.0 : @f))), M)
}

/*
   Returns f rebuilt with the same sign and exponent but a significand
   masked with ~((1 << d.nabs) - 1). Used by Procedure 2 of expm1gen
   to split f into u = round(f, d) and v = f - u.

   NOTE(review): need_round_away receives the already-masked
   significand (s & mask), not s itself, so the bits below d.nabs are
   always zero by the time it inspects them. Confirm against
   need_round_away's contract that this is intended.

   NOTE(review): "1 + s & mask" parses as either (1 + s) & mask or
   1 + (s & mask) depending on the relative precedence of + and &.
   Neither adds one unit at bit position d.nabs; confirm the intended
   rounding increment.
 */
generic round = {f : @f, d : fltdesc(@f, @u, @i) :: numeric,floating,std.equatable @f, numeric,integral @u, numeric,integral @i, roundable @f -> @i
	/* b is not used below */
	var b = d.tobits(f)
	var n, e, s
	(n, e, s) = d.explode(f)
	/* mask with the low d.nabs bits cleared */
	var mask = ~((1 << d.nabs) - 1)
	if need_round_away(0, ((s & mask) : uint64), (d.nabs : int64) + 1)
		-> d.assem(n, e, 1 + s & mask)
	;;
	-> d.assem(n, e, s & mask)
}