Skip to content

Commit 550888e

Browse files
mrakhandrewrk
authored andcommitted
std: improve random float generation
1 parent e0a514d commit 550888e

File tree

2 files changed

+171
-9
lines changed

2 files changed

+171
-9
lines changed

lib/std/rand.zig

Lines changed: 119 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ const math = std.math;
1616
const ziggurat = @import("rand/ziggurat.zig");
1717
const maxInt = std.math.maxInt;
1818

19+
const Dilbert = @import("rand/Dilbert.zig");
20+
1921
/// Fast unbiased random numbers.
2022
pub const DefaultPrng = Xoshiro256;
2123

@@ -249,18 +251,51 @@ pub const Random = struct {
249251

250252
/// Return a floating point value evenly distributed in the range [0, 1).
251253
pub fn float(r: Random, comptime T: type) T {
252-
// Generate a uniform value between [1, 2) and scale down to [0, 1).
253-
// Note: The lowest mantissa bit is always set to 0 so we only use half the available range.
254+
// Generate a uniformly random value for the mantissa.
255+
// Then generate an exponentially biased random value for the exponent.
256+
// Over the previous method, this has the advantage of being able to
257+
// represent every possible value in the available range.
254258
switch (T) {
255259
f32 => {
256-
const s = r.int(u32);
257-
const repr = (0x7f << 23) | (s >> 9);
258-
return @bitCast(f32, repr) - 1.0;
260+
// Use 23 random bits for the mantissa, and the rest for the exponent.
261+
// If all 41 bits are zero, generate additional random bits, until a
262+
// set bit is found, or 126 bits have been generated.
263+
const rand = r.int(u64);
264+
var rand_lz = @clz(u64, rand | 0x7FFFFF);
265+
if (rand_lz == 41) {
266+
rand_lz += @clz(u64, r.int(u64));
267+
if (rand_lz == 41 + 64) {
268+
// It is astronomically unlikely to reach this point.
269+
rand_lz += @clz(u32, r.int(u32) | 0x7FF);
270+
}
271+
}
272+
const mantissa = @truncate(u23, rand);
273+
const exponent = @as(u32, 126 - rand_lz) << 23;
274+
return @bitCast(f32, exponent | mantissa);
259275
},
260276
f64 => {
261-
const s = r.int(u64);
262-
const repr = (0x3ff << 52) | (s >> 12);
263-
return @bitCast(f64, repr) - 1.0;
277+
// Use 52 random bits for the mantissa, and the rest for the exponent.
278+
// If all 12 bits are zero, generate additional random bits, until a
279+
// set bit is found, or 1022 bits have been generated.
280+
const rand = r.int(u64);
281+
var rand_lz: u64 = @clz(u64, rand | 0xFFFFFFFFFFFFF);
282+
if (rand_lz == 12) {
283+
while (true) {
284+
// It is astronomically unlikely for this loop to execute more than once.
285+
const addl_rand_lz = @clz(u64, r.int(u64));
286+
rand_lz += addl_rand_lz;
287+
if (addl_rand_lz != 64) {
288+
break;
289+
}
290+
if (rand_lz >= 1022) {
291+
rand_lz = 1022;
292+
break;
293+
}
294+
}
295+
}
296+
const mantissa = rand & 0xFFFFFFFFFFFFF;
297+
const exponent = (1022 - rand_lz) << 52;
298+
return @bitCast(f64, exponent | mantissa);
264299
},
265300
else => @compileError("unknown floating point type"),
266301
}
@@ -573,7 +608,7 @@ test "splitmix64 sequence" {
573608
}
574609

575610
// Actual Random helper function tests, pcg engine is assumed correct.
576-
test "Random float" {
611+
test "Random float correctness" {
577612
var prng = DefaultPrng.init(0);
578613
const random = prng.random();
579614

@@ -589,6 +624,81 @@ test "Random float" {
589624
}
590625
}
591626

627+
// Check the "astronomically unlikely" code paths.
628+
test "Random float coverage" {
    // A byte source that only ever yields zero never produces a set bit, so
    // float() has to take its additional-random-bits fallback branches for
    // both widths. The expected result in that case is exactly 0.0.
    const zero_pattern = [_]u8{0};
    var zero_source = try Dilbert.init(&zero_pattern);
    const rng = zero_source.random();

    // Keep the draw order: f64 first, then f32.
    const from_f64 = rng.float(f64);
    const from_f32 = rng.float(f32);

    try expect(from_f32 == 0.0);
    try expect(from_f64 == 0.0);
}
638+
639+
test "Random float chi-square goodness of fit" {
    // Draw `num_numbers` floats of each width, bin them into `num_buckets`
    // equal-width buckets over [0, 1), and compare the observed bucket counts
    // against a uniform distribution with Pearson's chi-square statistic.
    const num_numbers = 100000;
    const num_buckets = 1000;

    var f32_hist = std.AutoHashMap(u32, u32).init(std.testing.allocator);
    defer f32_hist.deinit();
    var f64_hist = std.AutoHashMap(u64, u32).init(std.testing.allocator);
    defer f64_hist.deinit();

    var prng = DefaultPrng.init(0);
    const random = prng.random();

    var i: usize = 0;
    while (i < num_numbers) : (i += 1) {
        const rand_f32 = random.float(f32);
        const rand_f64 = random.float(f64);

        // The first sample that lands in a bucket must count as 1, not 0;
        // otherwise every non-empty bucket is undercounted by one.
        const f32_put = try f32_hist.getOrPut(@floatToInt(u32, rand_f32 * @intToFloat(f32, num_buckets)));
        if (f32_put.found_existing) {
            f32_put.value_ptr.* += 1;
        } else {
            f32_put.value_ptr.* = 1;
        }

        const f64_put = try f64_hist.getOrPut(@floatToInt(u64, rand_f64 * @intToFloat(f64, num_buckets)));
        if (f64_put.found_existing) {
            f64_put.value_ptr.* += 1;
        } else {
            f64_put.value_ptr.* = 1;
        }
    }

    // Under the uniform hypothesis every bucket expects the same count.
    const expected = @intToFloat(f64, num_numbers) / @intToFloat(f64, num_buckets);

    var f32_chi_square: f64 = 0;
    var f64_chi_square: f64 = 0;

    {
        var j: u32 = 0;
        while (j < num_buckets) : (j += 1) {
            // Buckets that never received a sample are absent from the map.
            const count = @intToFloat(f64, f32_hist.get(j) orelse 0);
            const delta = count - expected;
            f32_chi_square += (delta * delta) / expected;
        }
    }

    {
        var j: u64 = 0;
        while (j < num_buckets) : (j += 1) {
            // Buckets that never received a sample are absent from the map.
            const count = @intToFloat(f64, f64_hist.get(j) orelse 0);
            const delta = count - expected;
            f64_chi_square += (delta * delta) / expected;
        }
    }

    // Corresponds to a p-value > 0.05.
    // Critical value is calculated by opening a Python interpreter and running:
    // scipy.stats.chi2.isf(0.05, num_buckets - 1)
    const critical_value = 1073.6426506574246;
    try expect(f32_chi_square < critical_value);
    try expect(f64_chi_square < critical_value);
}
701+
592702
test "Random shuffle" {
593703
var prng = DefaultPrng.init(0);
594704
const random = prng.random();

lib/std/rand/Dilbert.zig

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
//! Dilbert PRNG
2+
//! Do not use this PRNG! It is meant to be predictable, for the purposes of test reproducibility and coverage.
3+
//! Its output is just a repeat of a user-specified byte pattern.
4+
//! Name is a reference to this comic: https://dilbert.com/strip/2001-10-25
5+
6+
const std = @import("std");
7+
const Random = std.rand.Random;
8+
const math = std.math;
9+
const Dilbert = @This();
10+
11+
pattern: []const u8 = undefined,
12+
curr_idx: usize = 0,
13+
14+
/// Create a generator that replays `pattern`, starting from its first byte.
/// Returns `error.EmptyPattern` when `pattern` contains no bytes.
/// The slice is stored as-is (not copied).
pub fn init(pattern: []const u8) !Dilbert {
    if (pattern.len == 0) return error.EmptyPattern;

    return Dilbert{
        .pattern = pattern,
        .curr_idx = 0,
    };
}
22+
23+
/// Expose this generator through the generic `Random` interface,
/// with `fill` acting as the byte source.
pub fn random(self: *Dilbert) Random {
    const interface = Random.init(self, fill);
    return interface;
}
26+
27+
/// Overwrite every byte of `buf` with successive pattern bytes, resuming at
/// `curr_idx` and wrapping back to the start of the pattern after its last byte.
pub fn fill(self: *Dilbert, buf: []u8) void {
    var out_idx: usize = 0;
    while (out_idx < buf.len) : (out_idx += 1) {
        buf[out_idx] = self.pattern[self.curr_idx];
        // Advance the read position, wrapping at the end of the pattern.
        self.curr_idx = (self.curr_idx + 1) % self.pattern.len;
    }
}
33+
34+
test "Dilbert fill" {
    var generator = try Dilbert.init("9nine");

    // The 5-byte pattern "9nine" (39 6E 69 6E 65) read back in consecutive
    // 8-byte windows, each written here as a big-endian u64.
    const expected_words = [_]u64{
        0x396E696E65396E69,
        0x6E65396E696E6539,
        0x6E696E65396E696E,
        0x65396E696E65396E,
        0x696E65396E696E65,
    };

    for (expected_words) |word| {
        var expected: [8]u8 = undefined;
        std.mem.writeIntBig(u64, &expected, word);

        var actual: [8]u8 = undefined;
        generator.fill(&actual);

        try std.testing.expect(std.mem.eql(u8, &expected, &actual));
    }
}

0 commit comments

Comments
 (0)