Skip to content

Commit ed63d6c

Browse files
authored
Merge pull request #10428 from mrakh/rand_float_improvement
Improve stdlib's random float generation
2 parents e0a514d + 7bedeb9 commit ed63d6c

File tree

2 files changed

+489
-326
lines changed

2 files changed

+489
-326
lines changed

lib/std/rand.zig

Lines changed: 42 additions & 326 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
const std = @import("std.zig");
1010
const builtin = @import("builtin");
1111
const assert = std.debug.assert;
12-
const expect = std.testing.expect;
13-
const expectEqual = std.testing.expectEqual;
1412
const mem = std.mem;
1513
const math = std.math;
1614
const ziggurat = @import("rand/ziggurat.zig");
@@ -249,18 +247,51 @@ pub const Random = struct {
249247

250248
/// Return a floating point value evenly distributed in the range [0, 1).
251249
pub fn float(r: Random, comptime T: type) T {
252-
// Generate a uniform value between [1, 2) and scale down to [0, 1).
253-
// Note: The lowest mantissa bit is always set to 0 so we only use half the available range.
250+
// Generate a uniformly random value between for the mantissa.
251+
// Then generate an exponentially biased random value for the exponent.
252+
// Over the previous method, this has the advantage of being able to
253+
// represent every possible value in the available range.
254254
switch (T) {
255255
f32 => {
256-
const s = r.int(u32);
257-
const repr = (0x7f << 23) | (s >> 9);
258-
return @bitCast(f32, repr) - 1.0;
256+
// Use 23 random bits for the mantissa, and the rest for the exponent.
257+
// If all 41 bits are zero, generate additional random bits, until a
258+
// set bit is found, or 126 bits have been generated.
259+
const rand = r.int(u64);
260+
var rand_lz = @clz(u64, rand | 0x7FFFFF);
261+
if (rand_lz == 41) {
262+
rand_lz += @clz(u64, r.int(u64));
263+
if (rand_lz == 41 + 64) {
264+
// It is astronomically unlikely to reach this point.
265+
rand_lz += @clz(u32, r.int(u32) | 0x7FF);
266+
}
267+
}
268+
const mantissa = @truncate(u23, rand);
269+
const exponent = @as(u32, 126 - rand_lz) << 23;
270+
return @bitCast(f32, exponent | mantissa);
259271
},
260272
f64 => {
261-
const s = r.int(u64);
262-
const repr = (0x3ff << 52) | (s >> 12);
263-
return @bitCast(f64, repr) - 1.0;
273+
// Use 52 random bits for the mantissa, and the rest for the exponent.
274+
// If all 12 bits are zero, generate additional random bits, until a
275+
// set bit is found, or 1022 bits have been generated.
276+
const rand = r.int(u64);
277+
var rand_lz: u64 = @clz(u64, rand | 0xFFFFFFFFFFFFF);
278+
if (rand_lz == 12) {
279+
while (true) {
280+
// It is astronomically unlikely for this loop to execute more than once.
281+
const addl_rand_lz = @clz(u64, r.int(u64));
282+
rand_lz += addl_rand_lz;
283+
if (addl_rand_lz != 64) {
284+
break;
285+
}
286+
if (rand_lz >= 1022) {
287+
rand_lz = 1022;
288+
break;
289+
}
290+
}
291+
}
292+
const mantissa = rand & 0xFFFFFFFFFFFFF;
293+
const exponent = (1022 - rand_lz) << 52;
294+
return @bitCast(f64, exponent | mantissa);
264295
},
265296
else => @compileError("unknown floating point type"),
266297
}
@@ -319,221 +350,6 @@ pub fn limitRangeBiased(comptime T: type, random_int: T, less_than: T) T {
319350
return @intCast(T, m >> bits);
320351
}
321352

322-
const SequentialPrng = struct {
323-
const Self = @This();
324-
next_value: u8,
325-
326-
pub fn init() Self {
327-
return Self{
328-
.next_value = 0,
329-
};
330-
}
331-
332-
pub fn random(self: *Self) Random {
333-
return Random.init(self, fill);
334-
}
335-
336-
pub fn fill(self: *Self, buf: []u8) void {
337-
for (buf) |*b| {
338-
b.* = self.next_value;
339-
}
340-
self.next_value +%= 1;
341-
}
342-
};
343-
344-
test "Random int" {
345-
try testRandomInt();
346-
comptime try testRandomInt();
347-
}
348-
fn testRandomInt() !void {
349-
var rng = SequentialPrng.init();
350-
const random = rng.random();
351-
352-
try expect(random.int(u0) == 0);
353-
354-
rng.next_value = 0;
355-
try expect(random.int(u1) == 0);
356-
try expect(random.int(u1) == 1);
357-
try expect(random.int(u2) == 2);
358-
try expect(random.int(u2) == 3);
359-
try expect(random.int(u2) == 0);
360-
361-
rng.next_value = 0xff;
362-
try expect(random.int(u8) == 0xff);
363-
rng.next_value = 0x11;
364-
try expect(random.int(u8) == 0x11);
365-
366-
rng.next_value = 0xff;
367-
try expect(random.int(u32) == 0xffffffff);
368-
rng.next_value = 0x11;
369-
try expect(random.int(u32) == 0x11111111);
370-
371-
rng.next_value = 0xff;
372-
try expect(random.int(i32) == -1);
373-
rng.next_value = 0x11;
374-
try expect(random.int(i32) == 0x11111111);
375-
376-
rng.next_value = 0xff;
377-
try expect(random.int(i8) == -1);
378-
rng.next_value = 0x11;
379-
try expect(random.int(i8) == 0x11);
380-
381-
rng.next_value = 0xff;
382-
try expect(random.int(u33) == 0x1ffffffff);
383-
rng.next_value = 0xff;
384-
try expect(random.int(i1) == -1);
385-
rng.next_value = 0xff;
386-
try expect(random.int(i2) == -1);
387-
rng.next_value = 0xff;
388-
try expect(random.int(i33) == -1);
389-
}
390-
391-
test "Random boolean" {
392-
try testRandomBoolean();
393-
comptime try testRandomBoolean();
394-
}
395-
fn testRandomBoolean() !void {
396-
var rng = SequentialPrng.init();
397-
const random = rng.random();
398-
399-
try expect(random.boolean() == false);
400-
try expect(random.boolean() == true);
401-
try expect(random.boolean() == false);
402-
try expect(random.boolean() == true);
403-
}
404-
405-
test "Random enum" {
406-
try testRandomEnumValue();
407-
comptime try testRandomEnumValue();
408-
}
409-
fn testRandomEnumValue() !void {
410-
const TestEnum = enum {
411-
First,
412-
Second,
413-
Third,
414-
};
415-
var rng = SequentialPrng.init();
416-
const random = rng.random();
417-
rng.next_value = 0;
418-
try expect(random.enumValue(TestEnum) == TestEnum.First);
419-
try expect(random.enumValue(TestEnum) == TestEnum.First);
420-
try expect(random.enumValue(TestEnum) == TestEnum.First);
421-
}
422-
423-
test "Random intLessThan" {
424-
@setEvalBranchQuota(10000);
425-
try testRandomIntLessThan();
426-
comptime try testRandomIntLessThan();
427-
}
428-
fn testRandomIntLessThan() !void {
429-
var rng = SequentialPrng.init();
430-
const random = rng.random();
431-
432-
rng.next_value = 0xff;
433-
try expect(random.uintLessThan(u8, 4) == 3);
434-
try expect(rng.next_value == 0);
435-
try expect(random.uintLessThan(u8, 4) == 0);
436-
try expect(rng.next_value == 1);
437-
438-
rng.next_value = 0;
439-
try expect(random.uintLessThan(u64, 32) == 0);
440-
441-
// trigger the bias rejection code path
442-
rng.next_value = 0;
443-
try expect(random.uintLessThan(u8, 3) == 0);
444-
// verify we incremented twice
445-
try expect(rng.next_value == 2);
446-
447-
rng.next_value = 0xff;
448-
try expect(random.intRangeLessThan(u8, 0, 0x80) == 0x7f);
449-
rng.next_value = 0xff;
450-
try expect(random.intRangeLessThan(u8, 0x7f, 0xff) == 0xfe);
451-
452-
rng.next_value = 0xff;
453-
try expect(random.intRangeLessThan(i8, 0, 0x40) == 0x3f);
454-
rng.next_value = 0xff;
455-
try expect(random.intRangeLessThan(i8, -0x40, 0x40) == 0x3f);
456-
rng.next_value = 0xff;
457-
try expect(random.intRangeLessThan(i8, -0x80, 0) == -1);
458-
459-
rng.next_value = 0xff;
460-
try expect(random.intRangeLessThan(i3, -4, 0) == -1);
461-
rng.next_value = 0xff;
462-
try expect(random.intRangeLessThan(i3, -2, 2) == 1);
463-
}
464-
465-
test "Random intAtMost" {
466-
@setEvalBranchQuota(10000);
467-
try testRandomIntAtMost();
468-
comptime try testRandomIntAtMost();
469-
}
470-
fn testRandomIntAtMost() !void {
471-
var rng = SequentialPrng.init();
472-
const random = rng.random();
473-
474-
rng.next_value = 0xff;
475-
try expect(random.uintAtMost(u8, 3) == 3);
476-
try expect(rng.next_value == 0);
477-
try expect(random.uintAtMost(u8, 3) == 0);
478-
479-
// trigger the bias rejection code path
480-
rng.next_value = 0;
481-
try expect(random.uintAtMost(u8, 2) == 0);
482-
// verify we incremented twice
483-
try expect(rng.next_value == 2);
484-
485-
rng.next_value = 0xff;
486-
try expect(random.intRangeAtMost(u8, 0, 0x7f) == 0x7f);
487-
rng.next_value = 0xff;
488-
try expect(random.intRangeAtMost(u8, 0x7f, 0xfe) == 0xfe);
489-
490-
rng.next_value = 0xff;
491-
try expect(random.intRangeAtMost(i8, 0, 0x3f) == 0x3f);
492-
rng.next_value = 0xff;
493-
try expect(random.intRangeAtMost(i8, -0x40, 0x3f) == 0x3f);
494-
rng.next_value = 0xff;
495-
try expect(random.intRangeAtMost(i8, -0x80, -1) == -1);
496-
497-
rng.next_value = 0xff;
498-
try expect(random.intRangeAtMost(i3, -4, -1) == -1);
499-
rng.next_value = 0xff;
500-
try expect(random.intRangeAtMost(i3, -2, 1) == 1);
501-
502-
try expect(random.uintAtMost(u0, 0) == 0);
503-
}
504-
505-
test "Random Biased" {
506-
var prng = DefaultPrng.init(0);
507-
const random = prng.random();
508-
// Not thoroughly checking the logic here.
509-
// Just want to execute all the paths with different types.
510-
511-
try expect(random.uintLessThanBiased(u1, 1) == 0);
512-
try expect(random.uintLessThanBiased(u32, 10) < 10);
513-
try expect(random.uintLessThanBiased(u64, 20) < 20);
514-
515-
try expect(random.uintAtMostBiased(u0, 0) == 0);
516-
try expect(random.uintAtMostBiased(u1, 0) <= 0);
517-
try expect(random.uintAtMostBiased(u32, 10) <= 10);
518-
try expect(random.uintAtMostBiased(u64, 20) <= 20);
519-
520-
try expect(random.intRangeLessThanBiased(u1, 0, 1) == 0);
521-
try expect(random.intRangeLessThanBiased(i1, -1, 0) == -1);
522-
try expect(random.intRangeLessThanBiased(u32, 10, 20) >= 10);
523-
try expect(random.intRangeLessThanBiased(i32, 10, 20) >= 10);
524-
try expect(random.intRangeLessThanBiased(u64, 20, 40) >= 20);
525-
try expect(random.intRangeLessThanBiased(i64, 20, 40) >= 20);
526-
527-
// uncomment for broken module error:
528-
//expect(random.intRangeAtMostBiased(u0, 0, 0) == 0);
529-
try expect(random.intRangeAtMostBiased(u1, 0, 1) >= 0);
530-
try expect(random.intRangeAtMostBiased(i1, -1, 0) >= -1);
531-
try expect(random.intRangeAtMostBiased(u32, 10, 20) >= 10);
532-
try expect(random.intRangeAtMostBiased(i32, 10, 20) >= 10);
533-
try expect(random.intRangeAtMostBiased(u64, 20, 40) >= 20);
534-
try expect(random.intRangeAtMostBiased(i64, 20, 40) >= 20);
535-
}
536-
537353
// Generator to extend 64-bit seed values into longer sequences.
538354
//
539355
// The number of cycles is thus limited to 64-bits regardless of the engine, but this
@@ -555,107 +371,7 @@ pub const SplitMix64 = struct {
555371
}
556372
};
557373

558-
test "splitmix64 sequence" {
559-
var r = SplitMix64.init(0xaeecf86f7878dd75);
560-
561-
const seq = [_]u64{
562-
0x5dbd39db0178eb44,
563-
0xa9900fb66b397da3,
564-
0x5c1a28b1aeebcf5c,
565-
0x64a963238f776912,
566-
0xc6d4177b21d1c0ab,
567-
0xb2cbdbdb5ea35394,
568-
};
569-
570-
for (seq) |s| {
571-
try expect(s == r.next());
572-
}
573-
}
574-
575-
// Actual Random helper function tests, pcg engine is assumed correct.
576-
test "Random float" {
577-
var prng = DefaultPrng.init(0);
578-
const random = prng.random();
579-
580-
var i: usize = 0;
581-
while (i < 1000) : (i += 1) {
582-
const val1 = random.float(f32);
583-
try expect(val1 >= 0.0);
584-
try expect(val1 < 1.0);
585-
586-
const val2 = random.float(f64);
587-
try expect(val2 >= 0.0);
588-
try expect(val2 < 1.0);
589-
}
590-
}
591-
592-
test "Random shuffle" {
593-
var prng = DefaultPrng.init(0);
594-
const random = prng.random();
595-
596-
var seq = [_]u8{ 0, 1, 2, 3, 4 };
597-
var seen = [_]bool{false} ** 5;
598-
599-
var i: usize = 0;
600-
while (i < 1000) : (i += 1) {
601-
random.shuffle(u8, seq[0..]);
602-
seen[seq[0]] = true;
603-
try expect(sumArray(seq[0..]) == 10);
604-
}
605-
606-
// we should see every entry at the head at least once
607-
for (seen) |e| {
608-
try expect(e == true);
609-
}
610-
}
611-
612-
fn sumArray(s: []const u8) u32 {
613-
var r: u32 = 0;
614-
for (s) |e|
615-
r += e;
616-
return r;
617-
}
618-
619-
test "Random range" {
620-
var prng = DefaultPrng.init(0);
621-
const random = prng.random();
622-
623-
try testRange(random, -4, 3);
624-
try testRange(random, -4, -1);
625-
try testRange(random, 10, 14);
626-
try testRange(random, -0x80, 0x7f);
627-
}
628-
629-
fn testRange(r: Random, start: i8, end: i8) !void {
630-
try testRangeBias(r, start, end, true);
631-
try testRangeBias(r, start, end, false);
632-
}
633-
fn testRangeBias(r: Random, start: i8, end: i8, biased: bool) !void {
634-
const count = @intCast(usize, @as(i32, end) - @as(i32, start));
635-
var values_buffer = [_]bool{false} ** 0x100;
636-
const values = values_buffer[0..count];
637-
var i: usize = 0;
638-
while (i < count) {
639-
const value: i32 = if (biased) r.intRangeLessThanBiased(i8, start, end) else r.intRangeLessThan(i8, start, end);
640-
const index = @intCast(usize, value - start);
641-
if (!values[index]) {
642-
i += 1;
643-
values[index] = true;
644-
}
645-
}
646-
}
647-
648-
test "CSPRNG" {
649-
var secret_seed: [DefaultCsprng.secret_seed_length]u8 = undefined;
650-
std.crypto.random.bytes(&secret_seed);
651-
var csprng = DefaultCsprng.init(secret_seed);
652-
const random = csprng.random();
653-
const a = random.int(u64);
654-
const b = random.int(u64);
655-
const c = random.int(u64);
656-
try expect(a ^ b ^ c != 0);
657-
}
658-
659374
test {
660375
std.testing.refAllDecls(@This());
376+
_ = @import("rand/test.zig");
661377
}

0 commit comments

Comments
 (0)