ref: a4fdef8a867263627510ea20e9b4a758549f7c98
dir: /lib/math/fma-impl.myr/
use std

use "util"

pkg math =
	pkglocal const fma32 : (x : flt32, y : flt32, z : flt32 -> flt32)
	pkglocal const fma64 : (x : flt64, y : flt64, z : flt64 -> flt64)
;;

/* Biased-exponent field masks; an all-ones field means inf or NaN */
const exp_mask32 : uint32 = 0xff << 23
const exp_mask64 : uint64 = 0x7ff << 52

/*
   fma32: returns x * y + z rounded to flt32.

   The operands are widened to flt64 (via the util conversion helpers).
   The product of two 24-bit significands fits in flt64's 53 bits, so
   prod is exact. The sum is then rounded to flt64 and converted to flt32.
   That double rounding can only go wrong when the flt64 sum lands exactly
   halfway between two flt32 values. The code below detects that case and,
   if needed, moves the flt64 significand by one unit so that the final
   conversion rounds the right way.
 */
pkglocal const fma32 = {x : flt32, y : flt32, z : flt32
	var xn, yn
	(xn, _, _) = std.flt32explode(x)
	(yn, _, _) = std.flt32explode(y)
	var xd : flt64 = flt64fromflt32(x)
	var yd : flt64 = flt64fromflt32(y)
	var zd : flt64 = flt64fromflt32(z)
	var prod : flt64 = xd * yd
	var pn, pe, ps
	(pn, pe, ps) = std.flt64explode(prod)
	/* exponent -1023 is the zero/subnormal encoding; treat it as -1022 */
	if pe == -1023
		pe = -1022
	;;
	if pn != (xn != yn)
		/* In case of NaNs, sign might not have been preserved */
		pn = (xn != yn)
		prod = std.flt64assem(pn, pe, ps)
	;;

	var r : flt64 = prod + zd
	var rn, re, rs
	(rn, re, rs) = std.flt64explode(r)

	/*
	   At this point, r is probably the correct answer. The only issue
	   is the rounding.

	   Ex 1: If x*y > 0 and z is a tiny, negative number, then adding z
	   probably does no rounding. However, if truncating to 23 bits of
	   precision would cause round-to-even, and that round would be
	   upwards, then we need to remember those trailing bits of z and
	   cancel the rounding.

	   Ex 2: If x, y, z > 0, and z is small, with

	                                    last bit in flt64 |
	                  last bit in flt32 v                 v
	          x * y = ...............101011..............11
	          z     =                                      10000...,

	   then x * y + z will be rounded to

	                  ...............101100..............00,

	   and then as a flt32 it will become

	                  ...............110,

	   even though, looking at the original bits, it doesn't "deserve"
	   the final rounding.

	   These can only happen if r is non-inf, non-NaN, and the lower 29
	   bits (53 - 24) correspond to "exactly halfway".
	 */
	if re == 1024 || rs & 0x1fffffff != 0x10000000
		-> flt32fromflt64(r)
	;;

	/*
	   At this point, a rounding is about to happen. We need to know
	   what direction that rounding is, so that we can tell if it's
	   wrong. +1 means "away from 0", -1 means "towards 0".
	 */
	var zn, ze, zs
	(zn, ze, zs) = std.flt64explode(zd)
	var round_direction = 0
	/* bit 29 is the last flt32 bit; round-to-even decides by its parity */
	if rs & 0x20000000 == 0
		round_direction = -1
	else
		round_direction = 1
	;;

	/* order the two flt64 addends by magnitude */
	var smaller, larger, smaller_e, larger_e
	if pe > ze || (pe == ze && ps > zs)
		(smaller, larger, smaller_e, larger_e) = (zs, ps, ze, pe)
	else
		(smaller, larger, smaller_e, larger_e) = (ps, zs, pe, ze)
	;;
	/* mask of the bits of smaller that fall below larger's last bit */
	var mask = shr((-1 : uint64), 64 - std.min(64, larger_e - smaller_e))
	var prevent_rounding = false
	if (round_direction > 0 && pn != zn) || (round_direction < 0 && pn == zn)
		/*
		   The prospective rounding disagrees with the signage. We are
		   potentially in the case of Ex 1.

		   Look at the bits (of the smaller flt64) that are outside the
		   range of r. If there are any such bits, we need to cancel the
		   rounding.

		   We certainly need to consider bits very far to the right, but
		   there's an awkwardness concerning the bit just outside the
		   flt64 range: it governed round-to-even, so it might have had
		   an effect. We only care about bits which did not have an
		   effect. Therefore, we perform the subtraction using only the
		   bits from smaller that lie in larger's range, then check
		   whether the result is susceptible to round-to-even.

		   (Since we only care about the last bit, and the base is 2,
		   subtraction or addition are equally useful.)
		 */
		if (larger ^ shr(smaller, larger_e - smaller_e)) & 0x1 == 0
			prevent_rounding = smaller & mask != 0
		;;
	else
		/*
		   The prospective rounding agrees with the signage. We are
		   potentially in the case of Ex 2.

		   We just need to check if r was obtained by rounding in the
		   addition step. In this case, we still check the
		   smaller/larger, and we only care about round-to-even. Any
		   rounding that happened previously is enough reason to
		   disqualify this next rounding.
		 */
		prevent_rounding = (larger ^ shr(smaller, larger_e - smaller_e)) & 0x1 != 0
	;;

	/* step r's significand off the exact halfway point, toward the correct side */
	if prevent_rounding
		if round_direction > 0
			rs--
		else
			rs++
		;;
	;;

	-> flt32fromflt64(std.flt64assem(rn, re, rs))
}

/*
   fma64: returns x * y + z for flt64 values.

   Special cases:
   - NaN or infinite x or y, a zero z, or a product x * y that evaluates
     to zero: plain flt64 arithmetic (x * y + z) is returned.
   - NaN or infinite z with finite x and y: z is returned unchanged.

   Otherwise, the product of the two 53-bit significands is formed exactly
   in a 128-bit [ prod_h ][ prod_l ] pair. z is added to it or subtracted
   from it inside that window, and the result is rounded to flt64. There
   are separate paths for subnormal results and for overflow to infinity.

   NOTE(review): the zero check uses the rounded flt64 product, so a
   nonzero product that underflows to zero also takes the early x * y + z
   path — TODO confirm this is acceptable near ties at the subnormal floor.
 */
pkglocal const fma64 = {x : flt64, y : flt64, z : flt64
	var xn : bool, yn : bool, zn : bool
	var xe : int64, ye : int64, ze : int64
	var xs : uint64, ys : uint64, zs : uint64

	var xb : uint64 = std.flt64bits(x)
	var yb : uint64 = std.flt64bits(y)
	var zb : uint64 = std.flt64bits(z)

	/* check for both NaNs and infinities */
	if xb & exp_mask64 == exp_mask64 || \
	   yb & exp_mask64 == exp_mask64
		-> x * y + z
	elif z == 0.0 || z == -0.0 || x * y == 0.0 || x * y == -0.0
		-> x * y + z
	elif zb & exp_mask64 == exp_mask64
		-> z
	;;

	(xn, xe, xs) = std.flt64explode(x)
	(yn, ye, ys) = std.flt64explode(y)
	(zn, ze, zs) = std.flt64explode(z)
	/* exponent -1023 is the subnormal encoding; the real scale is -1022 */
	if xe == -1023
		xe = -1022
	;;
	if ye == -1023
		ye = -1022
	;;
	if ze == -1023
		ze = -1022
	;;

	/*
	   Keep product in high/low uint64s. Significands have at most 53
	   bits, so splitting each into 32-bit halves keeps every partial
	   product (and t_m) within 64 bits.
	 */
	var xs_h : uint64 = xs >> 32
	var ys_h : uint64 = ys >> 32
	var xs_l : uint64 = xs & 0xffffffff
	var ys_l : uint64 = ys & 0xffffffff

	var t_l : uint64 = xs_l * ys_l
	var t_m : uint64 = xs_l * ys_h + xs_h * ys_l
	var t_h : uint64 = xs_h * ys_h

	var prod_l : uint64 = t_l + (t_m << 32)
	var prod_h : uint64 = t_h + (t_m >> 32)
	/* carry out of the low word */
	if t_l > prod_l
		prod_h++
	;;

	/*
	   *_lastbit_e is the exponent of bit 0 of a value's low word;
	   *_firstbit_e is the exponent of its highest set bit.
	 */
	var prod_n = xn != yn
	var prod_lastbit_e = (xe - 52) + (ye - 52)
	var prod_first1 = find_first1_64_hl(prod_h, prod_l, 105)
	var prod_firstbit_e = prod_lastbit_e + prod_first1

	var z_firstbit_e = ze
	var z_lastbit_e = ze - 52
	var z_first1 = 52

	/* subnormals could throw firstbit_e calculations out of whack */
	if (zb & exp_mask64 == 0)
		z_first1 = find_first1_64(zs, z_first1)
		z_firstbit_e = z_lastbit_e + z_first1
	;;

	var res_n
	var res_h = 0
	var res_l = 0
	var res_first1
	var res_lastbit_e
	var res_firstbit_e

	if prod_n == zn
		res_n = prod_n

		/*
		   Align prod and z so that the top bit of the result is either
		   53 or 54, then add.
		 */
		if prod_firstbit_e >= z_firstbit_e
			/*
			   [ prod_h ][ prod_l ]
			        [ z... ]
			 */
			res_lastbit_e = prod_lastbit_e
			(res_h, res_l) = (prod_h, prod_l)
			(res_h, res_l) = add_shifted(res_h, res_l, zs, z_lastbit_e - prod_lastbit_e)
		else
			/*
			          [ prod_h ][ prod_l ]
			   [ z... ]
			 */
			res_lastbit_e = z_lastbit_e - 64
			res_h = zs
			res_l = 0
			if prod_lastbit_e >= res_lastbit_e + 64
				/* In this situation, prod must be extremely subnormal */
				res_h += shl(prod_l, prod_lastbit_e - res_lastbit_e - 64)
			elif prod_lastbit_e >= res_lastbit_e
				res_h += shl(prod_h, prod_lastbit_e - res_lastbit_e)
				res_h += shr(prod_l, res_lastbit_e + 64 - prod_lastbit_e)
				res_l += shl(prod_l, prod_lastbit_e - res_lastbit_e)
			elif prod_lastbit_e + 64 >= res_lastbit_e
				res_h += shr(prod_h, res_lastbit_e - prod_lastbit_e)
				var l1 = shl(prod_h, prod_lastbit_e + 64 - res_lastbit_e)
				var l2 = shr(prod_l, res_lastbit_e - prod_lastbit_e)
				res_l = l1 + l2
				if res_l < l1
					res_h++
				;;
			elif prod_lastbit_e + 128 >= res_lastbit_e
				res_l += shr(prod_h, res_lastbit_e - prod_lastbit_e - 64)
			;;
		;;
	else
		/* opposite signs: subtract the smaller magnitude from the larger */
		match compare_hl_z(prod_h, prod_l, prod_firstbit_e, prod_lastbit_e, zs, z_firstbit_e, z_lastbit_e)
		| `std.Equal: -> 0.0
		| `std.Before:
			/* prod > z */
			res_n = prod_n
			res_lastbit_e = prod_lastbit_e
			(res_h, res_l) = sub_shifted(prod_h, prod_l, zs, z_lastbit_e - prod_lastbit_e)
		| `std.After:
			/* z > prod */
			res_n = zn
			res_lastbit_e = z_lastbit_e - 64
			(res_h, res_l) = sub_shifted(zs, 0, prod_h, prod_lastbit_e + 64 - (z_lastbit_e - 64))
			(res_h, res_l) = sub_shifted(res_h, res_l, prod_l, prod_lastbit_e - (z_lastbit_e - 64))
		;;
	;;

	/* index (0..127) of the highest set bit of [ res_h ][ res_l ] */
	res_first1 = 64 + find_first1_64(res_h, 55)
	if res_first1 == 63
		res_first1 = find_first1_64(res_l, 63)
	;;
	res_firstbit_e = res_first1 + res_lastbit_e

	/*
	   Finally, res_h and res_l are the high and low bits of the result.
	   They now need to be assembled into a flt64. Subnormals and
	   infinities could be a problem.
	 */
	var res_s = 0
	if res_firstbit_e <= -1023
		/*
		   Subnormal case: bit 0 of res_s must have exponent -1074
		   (= -1022 - 52).
		 */
		if res_lastbit_e + 128 < 12 - 1022
			res_s = shr(res_h, 12 - 1022 - (res_lastbit_e + 128))
			res_s |= shr(res_l, 12 - 1022 - (res_lastbit_e + 64))
		elif res_lastbit_e + 64 < 12 - 1022
			res_s = shl(res_h, -12 + (res_lastbit_e + 128) - (-1022))
			res_s |= shr(res_l, 12 - 1022 - (res_lastbit_e + 64))
		else
			res_s = shl(res_h, -12 + (res_lastbit_e + 128) - (-1022))
			res_s |= shl(res_l, -12 + (res_lastbit_e + 64) - (-1022))
		;;

		/* the argument is the 128-bit index of the exponent -1074 bit */
		if need_round_away(res_h, res_l, res_first1 + (-1074 - res_firstbit_e))
			res_s++
		;;

		/*
		   No need for exponents, they are all zero. (A carry into bit 52
		   from the rounding above yields the smallest normal encoding.)
		 */
		var res = res_s
		if res_n
			res |= (1 << 63)
		;;
		-> std.flt64frombits(res)
	;;

	if res_firstbit_e >= 1024
		/* Infinity case */
		if res_n
			-> std.flt64frombits(0xfff0000000000000)
		else
			-> std.flt64frombits(0x7ff0000000000000)
		;;
	;;

	/*
	   Extract the 53-bit significand (leading one at bit 52).
	   res_first1 - 52 is the number of result bits below it.
	 */
	if res_first1 - 52 >= 64
		res_s = shr(res_h, (res_first1 : int64) - 64 - 52)
		if need_round_away(res_h, res_l, res_first1 - 52)
			res_s++
		;;
	elif res_first1 - 52 >= 0
		res_s = shl(res_h, 64 - (res_first1 - 52))
		res_s |= shr(res_l, res_first1 - 52)
		if need_round_away(res_h, res_l, res_first1 - 52)
			res_s++
		;;
	else
		/*
		   Fewer than 53 significant bits: they all live in res_l
		   (res_h is zero here), so shift them up; nothing is lost and
		   no rounding is needed.
		 */
		res_s = shl(res_l, 52 - res_first1)
	;;

	/* The res_s++s might have messed everything up */
	if res_s & (1 << 53) != 0
		/* rounding carried out of the significand: renormalize */
		res_s >>= 1
		res_firstbit_e++
		if res_firstbit_e >= 1024
			if res_n
				-> std.flt64frombits(0xfff0000000000000)
			else
				-> std.flt64frombits(0x7ff0000000000000)
			;;
		;;
	;;

	-> std.flt64assem(res_n, res_firstbit_e, res_s)
}

/*
   Add (a << s) to [ h ][ l ], where if s < 0 then a corresponding
   right-shift is used. This is aligned such that if s == 0, then the
   result is [ h ][ l + a ]. Carries from l into h are propagated.

   Bits of a that are shifted below bit 0 of l are discarded. NOTE(review):
   no sticky bit is kept for them, which could matter for a later
   round-to-even tie decision — TODO confirm.
 */
const add_shifted = {h : uint64, l : uint64, a : uint64, s : int64
	if s >= 64
		-> (h + shl(a, s - 64), l)
	elif s >= 0
		var new_h = h + shr(a, 64 - s)
		var sa = shl(a, s)
		var new_l = l + sa
		if new_l < l
			new_h++
		;;
		-> (new_h, new_l)
	else
		var new_h = h
		var sa = shr(a, -s)
		var new_l = l + sa
		if new_l < l
			new_h++
		;;
		-> (new_h, new_l)
	;;
}

/*
   As above, but subtract (a << s), propagating the borrow from l into
   h. Bits shifted below bit 0 of l are likewise discarded.
 */
const sub_shifted = {h : uint64, l : uint64, a : uint64, s : int64
	if s >= 64
		-> (h - shl(a, s - 64), l)
	elif s >= 0
		var new_h = h - shr(a, 64 - s)
		var sa = shl(a, s)
		var new_l = l - sa
		if sa > l
			new_h--
		;;
		-> (new_h, new_l)
	else
		var new_h = h
		var sa = shr(a, -s)
		var new_l = l - sa
		if sa > l
			new_h--
		;;
		-> (new_h, new_l)
	;;
}

/*
   Compares the magnitude of the 128-bit value [ h ][ l ] against z.

   hl_firstbit_e / z_firstbit_e are the exponents of each value's highest
   set bit. hl_lastbit_e / z_lastbit_e are the exponents of bit 0 of l and
   of z. Returns `std.Before if [ h ][ l ] > z, `std.After if z > [ h ][ l ],
   and `std.Equal otherwise.

   fma64 only calls this after excluding a zero z and a zero x * y.
 */
const compare_hl_z = {h : uint64, l : uint64, hl_firstbit_e : int64, hl_lastbit_e : int64, z : uint64, z_firstbit_e : int64, z_lastbit_e : int64
	if hl_firstbit_e > z_firstbit_e
		-> `std.Before
	elif hl_firstbit_e < z_firstbit_e
		-> `std.After
	;;

	/*
	   Same leading exponent: walk both values from their top set bit
	   downward. h_k is negative when the top set bit lies in l.
	 */
	var h_k : int64 = (hl_firstbit_e - hl_lastbit_e - 64)
	var z_k : int64 = (z_firstbit_e - z_lastbit_e)
	while h_k >= 0 && z_k >= 0
		var h1 = h & shl(1, h_k) != 0
		var z1 = z & shl(1, z_k) != 0
		if h1 && !z1
			-> `std.Before
		elif !h1 && z1
			-> `std.After
		;;
		h_k--
		z_k--
	;;

	/*
	   The uncompared bits are indices h_k..0 (h_k + 1 of them), hence
	   the shift by 63 - h_k.
	 */
	if z_k < 0
		if (h & shr((-1 : uint64), 63 - h_k) != 0) || (l != 0)
			-> `std.Before
		else
			-> `std.Equal
		;;
	;;

	/*
	   Continue in l, aligned with where the walk through h stopped.
	   This is bit 63 when h was fully consumed, or the top set bit of l
	   when h held none of the value.
	 */
	var l_k : int64 = h_k + 64
	while l_k >= 0 && z_k >= 0
		var l1 = l & shl(1, l_k) != 0
		var z1 = z & shl(1, z_k) != 0
		if l1 && !z1
			-> `std.Before
		elif !l1 && z1
			-> `std.After
		;;
		l_k--
		z_k--
	;;

	/* any uncompared set bit (indices k..0) decides the comparison */
	if (z_k < 0) && (l & shr((-1 : uint64), 63 - l_k) != 0)
		-> `std.Before
	elif (l_k < 0) && (z & shr((-1 : uint64), 63 - z_k) != 0)
		-> `std.After
	;;

	-> `std.Equal
}