ref: debd13efe12eaca072229e9859aaa6df8708a09f
dir: /lib/math/fma-impl.myr/
use std

pkg math =
    pkglocal const fma32 : (x : flt32, y : flt32, z : flt32 -> flt32)
    pkglocal const fma64 : (x : flt64, y : flt64, z : flt64 -> flt64)
;;

const exp_mask32 : uint32 = 0xff << 23
const exp_mask64 : uint64 = 0x7ff << 52

pkglocal const fma32 = {x : flt32, y : flt32, z : flt32
    var xn, yn
    (xn, _, _) = std.flt32explode(x)
    (yn, _, _) = std.flt32explode(y)

    var xd : flt64 = flt64fromflt32(x)
    var yd : flt64 = flt64fromflt32(y)
    var zd : flt64 = flt64fromflt32(z)
    var prod : flt64 = xd * yd
    var pn, pe, ps
    (pn, pe, ps) = std.flt64explode(prod)
    if pe == -1023
        pe = -1022
    ;;
    if pn != (xn != yn)
        /* In case of NaNs, sign might not have been preserved */
        pn = (xn != yn)
        prod = std.flt64assem(pn, pe, ps)
    ;;

    var r : flt64 = prod + zd
    var rn, re, rs
    (rn, re, rs) = std.flt64explode(r)

    /*
       At this point, r is probably the correct answer. The only
       issue is the rounding.

       Ex 1: If x*y > 0 and z is a tiny, negative number, then
       adding z probably does no rounding. However, if truncating
       to 23 bits of precision would cause round-to-even, and that
       round would be upwards, then we need to remember those
       trailing bits of z and cancel the rounding.

       Ex 2: If x, y, z > 0, and z is small, with

                                             last bit in flt64
                                             |
                     last bit in flt32
                                      v      v
             x * y = ...............101011..11
                 z =                          10000...,

       then x * y + z will be rounded to

                     ...............101100..00,

       and then as a flt32 it will become

                     ...............110,

       even though, looking at the original bits, it doesn't
       "deserve" the final rounding.

       These can only happen if r is non-inf, non-NaN, and the
       lower 29 bits correspond to "exactly halfway".
     */
    if re == 1024 || rs & 0x1fffffff != 0x10000000
        -> flt32fromflt64(r)
    ;;

    /*
       At this point, a rounding is about to happen. We need to
       know what direction that rounding is, so that we can tell
       if it's wrong. +1 means "away from 0", -1 means "towards 0".
     */
    var zn, ze, zs
    (zn, ze, zs) = std.flt64explode(zd)
    var round_direction = 0
    if rs & 0x20000000 == 0
        round_direction = -1
    else
        round_direction = 1
    ;;

    var smaller, larger, smaller_e, larger_e
    if pe > ze || (pe == ze && ps > zs)
        (smaller, larger, smaller_e, larger_e) = (zs, ps, ze, pe)
    else
        (smaller, larger, smaller_e, larger_e) = (ps, zs, pe, ze)
    ;;
    var mask = shr((-1 : uint64), 64 - std.min(64, larger_e - smaller_e))
    var prevent_rounding = false
    if (round_direction > 0 && pn != zn) || (round_direction < 0 && pn == zn)
        /*
           The prospective rounding disagrees with the signage. We
           are potentially in the case of Ex 1.

           Look at the bits (of the smaller flt64) that are outside
           the range of r. If there are any such bits, we need to
           cancel the rounding.

           We certainly need to consider bits very far to the right,
           but there's an awkwardness concerning the bit just
           outside the flt64 range: it governed round-to-even, so it
           might have had an effect. We only care about bits which
           did not have an effect. Therefore, we perform the
           subtraction using only the bits from smaller that lie in
           larger's range, then check whether the result is
           susceptible to round-to-even. (Since we only care about
           the last bit, and the base is 2, subtraction or addition
           are equally useful.)
         */
        if (larger ^ shr(smaller, larger_e - smaller_e)) & 0x1 != 0
            mask >>= 1
        ;;
        prevent_rounding = smaller & mask != 0
    else
        /*
           The prospective rounding agrees with the signage. We are
           potentially in the case of Ex 2. We just need to check if
           r was obtained by rounding in the addition step. In this
           case, we still check the smaller/larger, and we only care
           about round-to-even. Any rounding that happened
           previously is enough reason to disqualify this next
           rounding.
         */
        prevent_rounding = (larger ^ shr(smaller, larger_e - smaller_e)) & 0x1 != 0
    ;;

    if prevent_rounding
        if round_direction > 0
            rs--
        else
            rs++
        ;;
    ;;

    -> flt32fromflt64(std.flt64assem(rn, re, rs))
}
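/*
   Editor's note (illustrative, not part of the original source):
   fma32 computes x*y + z with a single rounding at the very end,
   so

       math.fma32(a, b, c)

   can differ in the last bit from (a * b) + c, which rounds twice.
   The flt64 detour works because two 24-bit flt32 significands
   multiply to at most 48 bits, which a 53-bit flt64 significand
   holds exactly; only the final addition can round, and the code
   above detects and cancels the one bad double-rounding pattern.
 */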
const flt64fromflt32 = {f : flt32
    var n, e, s
    (n, e, s) = std.flt32explode(f)
    var xs : uint64 = (s : uint64)
    var xe : int64 = (e : int64)

    if e == 128
        -> std.flt64assem(n, 1024, xs)
    elif e == -127
        /*
           All subnormals in single precision (except 0.0s) can be
           upgraded to double precision, since the exponent range is
           so much wider.
         */
        var first1 = find_first1_64(xs, 23)
        if first1 < 0
            -> std.flt64assem(n, -1023, 0)
        ;;
        xs = xs << (52 - (first1 : uint64))
        xe = -126 - (23 - first1)
        -> std.flt64assem(n, xe, xs)
    ;;
    -> std.flt64assem(n, xe, xs << (52 - 23))
}

const flt32fromflt64 = {f : flt64
    var n : bool, e : int64, s : uint64
    (n, e, s) = std.flt64explode(f)
    var ts : uint32
    var te : int32 = (e : int32)

    if e >= 128
        if e == 1024 && s != 0
            /* NaN */
            -> std.flt32assem(n, 128, 1)
        else
            /* infinity */
            -> std.flt32assem(n, 128, 0)
        ;;
    ;;

    if e >= -127
        /* normal */
        ts = ((s >> (52 - 23)) : uint32)
        if need_round_away(0, s, 52 - 23)
            ts++
            if ts & (1 << 24) != 0
                ts >>= 1
                te++
            ;;
        ;;
        if te >= -126
            -> std.flt32assem(n, te, ts)
        ;;
    ;;

    /* subnormal already, will have to go to 0 */
    if e == -1023
        -> std.flt32assem(n, -127, 0)
    ;;

    /* subnormal (at least, it will be) */
    te = -127
    var shift : int64 = (52 - 23) + (-126 - e)
    var ts1 = shr(s, shift)
    ts = (ts1 : uint32)
    if need_round_away(0, s, shift)
        ts++
    ;;
    -> std.flt32assem(n, te, ts)
}
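/*
   Worked example for the narrowing above (illustrative, not part
   of the original source): in flt32fromflt64, the 52 - 23 = 29 low
   significand bits are dropped. need_round_away(0, s, 29) looks at
   bit 28 (the first omitted bit) and bits 27..0: when those bits
   are exactly 0x10000000, the value is an exact tie, and the
   result rounds up only when the kept bit 29 is 1, i.e. ties go to
   an even last bit.
 */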
pkglocal const fma64 = {x : flt64, y : flt64, z : flt64
    var xn : bool, yn : bool, zn : bool
    var xe : int64, ye : int64, ze : int64
    var xs : uint64, ys : uint64, zs : uint64

    var xb : uint64 = std.flt64bits(x)
    var yb : uint64 = std.flt64bits(y)
    var zb : uint64 = std.flt64bits(z)

    /* check for both NaNs and infinities */
    if xb & exp_mask64 == exp_mask64 || \
       yb & exp_mask64 == exp_mask64
        -> x * y + z
    elif z == 0.0 || z == -0.0 || x * y == 0.0 || x * y == -0.0
        -> x * y + z
    elif zb & exp_mask64 == exp_mask64
        -> z
    ;;

    (xn, xe, xs) = std.flt64explode(x)
    (yn, ye, ys) = std.flt64explode(y)
    (zn, ze, zs) = std.flt64explode(z)
    if xe == -1023
        xe = -1022
    ;;
    if ye == -1023
        ye = -1022
    ;;
    if ze == -1023
        ze = -1022
    ;;

    /* Keep product in high/low uint64s */
    var xs_h : uint64 = xs >> 32
    var ys_h : uint64 = ys >> 32
    var xs_l : uint64 = xs & 0xffffffff
    var ys_l : uint64 = ys & 0xffffffff

    var t_l : uint64 = xs_l * ys_l
    var t_m : uint64 = xs_l * ys_h + xs_h * ys_l
    var t_h : uint64 = xs_h * ys_h

    var prod_l : uint64 = t_l + (t_m << 32)
    var prod_h : uint64 = t_h + (t_m >> 32)
    if t_l > prod_l
        prod_h++
    ;;

    var prod_n = xn != yn
    var prod_lastbit_e = (xe - 52) + (ye - 52)
    var prod_first1 = find_first1_64_hl(prod_h, prod_l, 105)
    var prod_firstbit_e = prod_lastbit_e + prod_first1

    var z_firstbit_e = ze
    var z_lastbit_e = ze - 52
    var z_first1 = 52

    /* subnormals could throw firstbit_e calculations out of whack */
    if (zb & exp_mask64 == 0)
        z_first1 = find_first1_64(zs, z_first1)
        z_firstbit_e = z_lastbit_e + z_first1
    ;;

    var res_n
    var res_h = 0
    var res_l = 0
    var res_first1
    var res_lastbit_e
    var res_firstbit_e

    if prod_n == zn
        res_n = prod_n

        /*
           Align prod and z so that the top bit of the result is
           either 53 or 54, then add.
         */
        if prod_firstbit_e >= z_firstbit_e
            /*
               [ prod_h ][ prod_l ]
                      [ z...
             */
            res_lastbit_e = prod_lastbit_e
            (res_h, res_l) = (prod_h, prod_l)
            (res_h, res_l) = add_shifted(res_h, res_l, zs, z_lastbit_e - prod_lastbit_e)
        else
            /*
                  [ prod_h ][ prod_l ]
               [ z...
             */
            res_lastbit_e = z_lastbit_e - 64
            res_h = zs
            res_l = 0
            if prod_lastbit_e >= res_lastbit_e + 64
                /* In this situation, prod must be extremely subnormal */
                res_h += shl(prod_l, prod_lastbit_e - res_lastbit_e - 64)
            elif prod_lastbit_e >= res_lastbit_e
                res_h += shl(prod_h, prod_lastbit_e - res_lastbit_e)
                res_h += shr(prod_l, res_lastbit_e + 64 - prod_lastbit_e)
                res_l += shl(prod_l, prod_lastbit_e - res_lastbit_e)
            elif prod_lastbit_e + 64 >= res_lastbit_e
                res_h += shr(prod_h, res_lastbit_e - prod_lastbit_e)
                var l1 = shl(prod_h, prod_lastbit_e + 64 - res_lastbit_e)
                var l2 = shr(prod_l, res_lastbit_e - prod_lastbit_e)
                res_l = l1 + l2
                if res_l < l1
                    res_h++
                ;;
            elif prod_lastbit_e + 128 >= res_lastbit_e
                res_l += shr(prod_h, res_lastbit_e - prod_lastbit_e - 64)
            ;;
        ;;
    else
        match compare_hl_z(prod_h, prod_l, prod_firstbit_e, prod_lastbit_e, zs, z_firstbit_e, z_lastbit_e)
        | `std.Equal: -> 0.0
        | `std.Before:
            /* prod > z */
            res_n = prod_n
            res_lastbit_e = prod_lastbit_e
            (res_h, res_l) = sub_shifted(prod_h, prod_l, zs, z_lastbit_e - prod_lastbit_e)
        | `std.After:
            /* z > prod */
            res_n = zn
            res_lastbit_e = z_lastbit_e - 64
            (res_h, res_l) = sub_shifted(zs, 0, prod_h, prod_lastbit_e + 64 - (z_lastbit_e - 64))
            (res_h, res_l) = sub_shifted(res_h, res_l, prod_l, prod_lastbit_e - (z_lastbit_e - 64))
        ;;
    ;;

    res_first1 = 64 + find_first1_64(res_h, 55)
    if res_first1 == 63
        res_first1 = find_first1_64(res_l, 63)
    ;;
    res_firstbit_e = res_first1 + res_lastbit_e

    /*
       Finally, res_h and res_l are the high and low bits of the
       result. They now need to be assembled into a flt64.
       Subnormals and infinities could be a problem.
     */
    var res_s = 0
    if res_firstbit_e <= -1023
        /* Subnormal case */
        if res_lastbit_e + 128 < 12 - 1022
            res_s = shr(res_h, 12 - 1022 - (res_lastbit_e + 128))
            res_s |= shr(res_l, 12 - 1022 - (res_lastbit_e + 64))
        elif res_lastbit_e + 64 < 12 - 1022
            res_s = shl(res_h, -12 + (res_lastbit_e + 128) - (-1022))
            res_s |= shr(res_l, 12 - 1022 - (res_lastbit_e + 64))
        else
            res_s = shl(res_h, -12 + (res_lastbit_e + 128) - (-1022))
            res_s |= shl(res_l, -12 + (res_lastbit_e + 64) - (-1022))
        ;;

        if need_round_away(res_h, res_l, res_first1 + (-1074 - res_firstbit_e))
            res_s++
        ;;

        /* No need for exponents, they are all zero */
        var res = res_s
        if res_n
            res |= (1 << 63)
        ;;
        -> std.flt64frombits(res)
    ;;

    if res_firstbit_e >= 1024
        /* Infinity case */
        if res_n
            -> std.flt64frombits(0xfff0000000000000)
        else
            -> std.flt64frombits(0x7ff0000000000000)
        ;;
    ;;

    if res_first1 - 52 >= 64
        res_s = shr(res_h, (res_first1 : int64) - 64 - 52)
        if need_round_away(res_h, res_l, res_first1 - 52)
            res_s++
        ;;
    elif res_first1 - 52 >= 0
        res_s = shl(res_h, 64 - (res_first1 - 52))
        res_s |= shr(res_l, res_first1 - 52)
        if need_round_away(res_h, res_l, res_first1 - 52)
            res_s++
        ;;
    else
        res_s = shl(res_h, res_first1 - 52)
    ;;

    /* The res_s++s might have messed everything up */
    if res_s & (1 << 53) != 0
        res_s >>= 1
        res_firstbit_e++
        if res_firstbit_e >= 1024
            if res_n
                -> std.flt64frombits(0xfff0000000000000)
            else
                -> std.flt64frombits(0x7ff0000000000000)
            ;;
        ;;
    ;;

    -> std.flt64assem(res_n, res_firstbit_e, res_s)
}
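/*
   A note on the product in fma64 above (illustrative, not part of
   the original source): xs * ys is a full 106-bit product, formed
   by 32-bit schoolbook multiplication:

       xs * ys = (xs_h * ys_h) << 64
               + (xs_h * ys_l + xs_l * ys_h) << 32
               + (xs_l * ys_l)

   Since xs and ys are at most 53-bit significands, xs_h and ys_h
   fit in 21 bits, so the middle term t_m fits in 54 bits and
   cannot itself overflow a uint64; the only carry to track is the
   one from (t_m << 32) + t_l into prod_h, detected by the
   t_l > prod_l wraparound check.
 */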
/* >> and <<, but without wrapping when the shift is >= 64 */
const shr : (u : uint64, s : int64 -> uint64) = {u : uint64, s : int64
    if (s : uint64) >= 64
        -> 0
    else
        -> u >> (s : uint64)
    ;;
}

const shl : (u : uint64, s : int64 -> uint64) = {u : uint64, s : int64
    if (s : uint64) >= 64
        -> 0
    else
        -> u << (s : uint64)
    ;;
}

/*
   Add (a << s) to [ h ][ l ], where if s < 0 then a corresponding
   right-shift is used. This is aligned such that if s == 0, then
   the result is [ h ][ l + a ].
 */
const add_shifted = {h : uint64, l : uint64, a : uint64, s : int64
    if s >= 64
        -> (h + shl(a, s - 64), l)
    elif s >= 0
        var new_h = h + shr(a, 64 - s)
        var sa = shl(a, s)
        var new_l = l + sa
        if new_l < l
            new_h++
        ;;
        -> (new_h, new_l)
    else
        var new_h = h
        var sa = shr(a, -s)
        var new_l = l + sa
        if new_l < l
            new_h++
        ;;
        -> (new_h, new_l)
    ;;
}

/* As above, but subtract (a << s) */
const sub_shifted = {h : uint64, l : uint64, a : uint64, s : int64
    if s >= 64
        -> (h - shl(a, s - 64), l)
    elif s >= 0
        var new_h = h - shr(a, 64 - s)
        var sa = shl(a, s)
        var new_l = l - sa
        if sa > l
            new_h--
        ;;
        -> (new_h, new_l)
    else
        var new_h = h
        var sa = shr(a, -s)
        var new_l = l - sa
        if sa > l
            new_h--
        ;;
        -> (new_h, new_l)
    ;;
}
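/*
   Example of the carry/borrow handling (illustrative, not part of
   the original source):

       add_shifted(0, 0xffffffffffffffff, 1, 0)

   returns (1, 0): the addition wraps new_l around to zero, the
   new_l < l comparison detects the wrap, and the carry is moved
   into the high word. sub_shifted mirrors this with sa > l as the
   borrow test.
 */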
const compare_hl_z = {h : uint64, l : uint64, hl_firstbit_e : int64, hl_lastbit_e : int64, z : uint64, z_firstbit_e : int64, z_lastbit_e : int64
    if hl_firstbit_e > z_firstbit_e
        -> `std.Before
    elif hl_firstbit_e < z_firstbit_e
        -> `std.After
    ;;

    var h_k : int64 = (hl_firstbit_e - hl_lastbit_e - 64)
    var z_k : int64 = (z_firstbit_e - z_lastbit_e)
    while h_k >= 0 && z_k >= 0
        var h1 = h & shl(1, h_k) != 0
        var z1 = z & shl(1, z_k) != 0
        if h1 && !z1
            -> `std.Before
        elif !h1 && z1
            -> `std.After
        ;;
        h_k--
        z_k--
    ;;

    if z_k < 0
        if (h & shr((-1 : uint64), 64 - h_k) != 0) || (l != 0)
            -> `std.Before
        else
            -> `std.Equal
        ;;
    ;;

    var l_k : int64 = 63
    while l_k >= 0 && z_k >= 0
        var l1 = l & shl(1, l_k) != 0
        var z1 = z & shl(1, z_k) != 0
        if l1 && !z1
            -> `std.Before
        elif !l1 && z1
            -> `std.After
        ;;
        l_k--
        z_k--
    ;;

    if (z_k < 0) && (l & shr((-1 : uint64), 64 - l_k) != 0)
        -> `std.Before
    elif (l_k < 0) && (z & shr((-1 : uint64), 64 - z_k) != 0)
        -> `std.After
    ;;

    -> `std.Equal
}

/* Find the first 1 bit in a bitstring */
const find_first1_64 : (b : uint64, start : int64 -> int64) = {b : uint64, start : int64
    for var j = start; j >= 0; --j
        var m = shl(1, j)
        if b & m != 0
            -> j
        ;;
    ;;
    -> -1
}

const find_first1_64_hl = {h, l, start
    var first1_h = find_first1_64(h, start - 64)
    if first1_h >= 0
        -> first1_h + 64
    ;;
    -> find_first1_64(l, 63)
}

/*
   For [ h ][ l ], where bitpos_last is the position of the last
   bit that was included in the truncated result (l's last bit has
   position 0), decide whether rounding up/away is needed. This is
   true if

     - following bitpos_last is a 1, then a non-zero sequence, or

     - following bitpos_last is a 1, then a zero sequence, and the
       round would be to even
 */
const need_round_away = {h : uint64, l : uint64, bitpos_last : int64
    var first_omitted_is_1 = false
    var nonzero_beyond = false
    if bitpos_last > 64
        first_omitted_is_1 = h & shl(1, bitpos_last - 1 - 64) != 0
        nonzero_beyond = nonzero_beyond || h & shr((-1 : uint64), 2 + 64 - (bitpos_last - 64)) != 0
        nonzero_beyond = nonzero_beyond || (l != 0)
    else
        first_omitted_is_1 = l & shl(1, bitpos_last - 1) != 0
        nonzero_beyond = nonzero_beyond || l & shr((-1 : uint64), 1 + 64 - bitpos_last) != 0
    ;;

    if !first_omitted_is_1
        -> false
    ;;
    if nonzero_beyond
        -> true
    ;;

    var hl_is_odd = false
    if bitpos_last >= 64
        hl_is_odd = h & shl(1, bitpos_last - 64) != 0
    else
        hl_is_odd = l & shl(1, bitpos_last) != 0
    ;;
    -> hl_is_odd
}
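/*
   Tiny illustrative check of need_round_away (not part of the
   original source): with h = 0, l = 0b1110 and bitpos_last = 2,
   the first omitted bit (bit 1) is 1, nothing nonzero follows it,
   and the kept value ends in bit 2 = 1 (odd), so the tie rounds
   away. With l = 0b1100 the first omitted bit is 0, so no
   rounding happens.
 */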