Skip to content

Commit 91cbba5

Browse files
Improve calculation of the scale parameter for the uniform float distribution.
Closes rust-random/rand issue gh-1299.
1 parent 0f3eced commit 91cbba5

File tree

2 files changed

+379
-13
lines changed

2 files changed

+379
-13
lines changed

src/distributions/uniform.rs

Lines changed: 160 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -826,12 +826,109 @@ pub struct UniformFloat<X> {
826826
scale: X,
827827
}
828828

829+
/// A collection of values of type `T` that can be added up with a
/// compensated summation algorithm.
///
/// The implementation generated by `uniform_float_impl!` in this file is for
/// slices (`&[$ty]`) and uses Kahan's compensated summation.
trait Summable<T> {
830+
/// Returns the compensated sum of the values held by `self`.
fn compensated_sum(&self) -> T;
831+
}
832+
833+
/// Computation of the `scale` factor used when constructing a uniform float
/// sampler over the interval from `low` to `high`.
trait ScaleComputable<T> {
834+
/// Returns the scale for the given bounds. The implementation generated by
/// `uniform_float_impl!` chooses it so that `scale * max_rand + low < high`,
/// where `max_rand` is `1 - EPSILON`.
fn compute_scale(low: T, high: T) -> T;
835+
}
836+
829837
macro_rules! uniform_float_impl {
830838
($ty:ty, $uty:ident, $f_scalar:ident, $u_scalar:ident, $bits_to_discard:expr) => {
831839
// Selects `UniformFloat<$ty>` as the sampler type associated with `$ty`.
impl SampleUniform for $ty {
832840
type Sampler = UniformFloat<$ty>;
833841
}
834842

843+
impl Summable<$ty> for &[$ty] {
    /// Adds up the slice elements, in order, using Kahan's compensated
    /// summation: alongside the running total, the rounding error lost by
    /// each addition is tracked and subtracted from the next term.
    fn compensated_sum(&self) -> $ty {
        let zero = <$ty>::splat(0.0);
        // Fold state: (running total, error lost by the previous addition).
        let (total, _lost) = self.iter().fold((zero, zero), |(total, lost), &term| {
            let corrected = term - lost;
            let updated = total + corrected;
            // `(updated - total)` is what was actually added; the difference
            // from `corrected` is the rounding error of this step.
            (updated, (updated - total) - corrected)
        });
        total
    }
}
857+
858+
impl ScaleComputable<$ty> for $ty {
    /// Computes `scale` for the interval from `low` to `high` such that
    /// `scale * max_rand + low < high` holds while the same expression with
    /// the next float above `scale` does not (lane-wise for SIMD types).
    /// Here `max_rand` is `1 - EPSILON`.
    ///
    /// The visible caller checks `low < high` and that `high - low` is
    /// finite before calling this function.
    fn compute_scale(low: $ty, high: $ty) -> $ty {
        let eps = <$ty>::splat($f_scalar::EPSILON);

        // `max_rand` is 1.0 - eps. This is actually the second largest
        // float less than 1.0, because the spacing of the floats in the
        // interval [0.5, 1.0) is `eps/2`.
        let max_rand = <$ty>::splat(
            (::core::$u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0,
        );

        // `delta_high` is half the distance from `high` to the largest
        // float that is less than `high`. If `high` is subnormal or 0,
        // then `delta_high` will be 0. Why this is needed is explained
        // below.
        let delta_high = <$ty>::splat(0.5) * (high - high.utils_next_down());

        // We want `scale * max_rand + low < high`. Let `high_1` be the
        // (hypothetical) float that is the midpoint between `high` and
        // the largest float less than `high`. The midpoint is used
        // because any float calculation that would result in a value in
        // `(high_1, high)` would be rounded to `high`. The ideal
        // condition for the upper bound of `scale` is then
        //     `scale * max_rand + low = high_1`
        // or
        //     `scale = (high_1 - low)/max_rand`
        //
        // Write `high_1 = high - delta_high`, `max_rand = 1 - eps`,
        // and approximate `1/(1 - eps)` as `(1 + eps)`. Then we have
        //
        //     scale = (high - delta_high - low)*(1 + eps)
        //           = high - low + eps*high - eps*low - delta_high
        //
        // (The extremely small term `-delta_high*eps` has been ignored.)
        // The following uses Kahan's compensated summation to compute
        // `scale` from those terms.
        let terms: &[$ty] = &[high, -low, eps * high, -eps * low, -delta_high];
        let mut scale = terms.compensated_sum();

        // Empirical tests show that `scale` is generally within 1 or 2 ULPs
        // of the "ideal" scale. Next we adjust `scale`, if necessary, to
        // the ideal value.

        // Check that `scale * max_rand + low` is less than `high`. If it is
        // not, repeatedly adjust `scale` down by one ULP until the condition
        // is satisfied. Generally this requires 0 or 1 adjustments.
        // (The original `too_big_mask` is saved so we can use it again below.)
        let too_big_mask = (scale * max_rand + low).ge_mask(high);
        loop {
            let mask = (scale * max_rand + low).ge_mask(high);
            if !mask.any() {
                break;
            }
            scale = scale.decrease_masked(mask);
        }

        // We have ensured that `scale * max_rand + low < high`. Now see if
        // we can increase `scale` and still maintain that inequality. We
        // only need to do this for lanes where `scale` was not initially
        // too big.
        let mut mask = !too_big_mask;
        if mask.any() {
            loop {
                let next_scale = scale.increase_masked(mask);
                // Intersect with the previous `mask`: a lane whose increment
                // was rejected must stay out. Its `next_scale` is unchanged
                // on later iterations, so it would otherwise pass the `<`
                // test again and be pushed one ULP past the bound.
                mask = (next_scale * max_rand + low).lt_mask(high) & mask;
                if !mask.any() {
                    break;
                }
                scale = scale.increase_masked(mask);
            }
        }
        scale
    }
}
931+
835932
impl UniformSampler for UniformFloat<$ty> {
836933
type X = $ty;
837934

@@ -849,22 +946,12 @@ macro_rules! uniform_float_impl {
849946
if !(low.all_lt(high)) {
850947
return Err(Error::EmptyRange);
851948
}
852-
let max_rand = <$ty>::splat(
853-
(::core::$u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0,
854-
);
855949

856-
let mut scale = high - low;
857-
if !(scale.all_finite()) {
950+
if !((high - low).all_finite()) {
858951
return Err(Error::NonFinite);
859952
}
860953

861-
loop {
862-
let mask = (scale * max_rand + low).ge_mask(high);
863-
if !mask.any() {
864-
break;
865-
}
866-
scale = scale.decrease_masked(mask);
867-
}
954+
let scale = <$ty>::compute_scale(low, high);
868955

869956
debug_assert!(<$ty>::splat(0.0).all_le(scale));
870957

@@ -1430,6 +1517,67 @@ mod tests {
14301517
}
14311518
}
14321519

1520+
macro_rules! compute_scale_tests {
    ($($fname:ident: $ty:ident,)*) => {
        $(
            /// For every `(low, high)` interval below, `compute_scale` must
            /// return a positive `scale` with
            ///     scale * max_rand + low < high
            /// while the next float above `scale` must violate that bound:
            ///     next_up(scale) * max_rand + low >= high
            #[test]
            fn $fname() {
                let eps = $ty::EPSILON;
                let one = 1.0 as $ty;
                let intervals = [
                    (0.0 as $ty, 100.0 as $ty),
                    (-0.125 as $ty, 0.0 as $ty),
                    (0.0 as $ty, 0.125 as $ty),
                    (-1.5 as $ty, -0.0 as $ty),
                    (-0.0 as $ty, 1.5 as $ty),
                    (-1.0 as $ty, -0.875 as $ty),
                    (-1e35 as $ty, -1e25 as $ty),
                    (1e-35 as $ty, 1e-25 as $ty),
                    (-1e35 as $ty, 1e35 as $ty),
                    // Very small intervals: `high - low` is a modest
                    // multiple of the type's EPSILON.
                    (one - (11.5 as $ty) * eps, one - (0.5 as $ty) * eps),
                    (one - (196389.0 as $ty) * eps / (2.0 as $ty), one),
                    (one, one + one * eps),
                    (one, one + (2.0 as $ty) * eps),
                    (one - eps, one),
                    (one - eps, one + (2.0 as $ty) * eps),
                    (-1.0 as $ty, -1.0 as $ty + (2.0 as $ty) * eps),
                    (-2.0 as $ty, -2.0 as $ty + (17.0 as $ty) * eps),
                    (-11.0 as $ty, -11.0 as $ty + (68.0 as $ty) * eps),
                    // Extremely small intervals: both bounds are subnormal.
                    (-$ty::from_bits(3), $ty::from_bits(8)),
                    (-$ty::from_bits(5), -$ty::from_bits(1)),
                    // `high - low` is a significant fraction of the type's MAX.
                    ((0.5 as $ty) * $ty::MIN, (0.25 as $ty) * $ty::MAX),
                    ((0.25 as $ty) * $ty::MIN, (0.5 as $ty) * $ty::MAX),
                    ((0.5 as $ty) * $ty::MIN, (0.4999995 as $ty) * $ty::MAX),
                    ((0.75 as $ty) * $ty::MIN, 0.0 as $ty),
                    (0.0 as $ty, (0.75 as $ty) * $ty::MAX),
                ];
                let max_rand = one - eps;
                intervals.iter().for_each(|&(low, high)| {
                    // Largest value the sampler can produce for a given scale.
                    let upper_end = |s: $ty| s * max_rand + low;

                    let scale = <$ty>::compute_scale(low, high);
                    assert!(scale > 0.0 as $ty);
                    assert!(upper_end(scale) < high);
                    // `scale` is positive, so bumping its bit pattern by one
                    // yields the next float above it.
                    let bumped = <$ty>::from_bits(scale.to_bits() + 1);
                    assert!(upper_end(bumped) >= high);
                });
            }
        )*
    }
}
1575+
1576+
// Generate the `compute_scale` tests for the two scalar float types.
compute_scale_tests! {
1577+
test_compute_scale_f32: f32,
1578+
test_compute_scale_f64: f64,
1579+
}
1580+
14331581
#[test]
14341582
fn test_float_overflow() {
14351583
assert_eq!(Uniform::try_from(::core::f64::MIN..::core::f64::MAX), Err(Error::NonFinite));

0 commit comments

Comments
 (0)