Skip to content

Commit f080e47

Browse files
committed
Move UniformInt SIMD implementations to new module
1 parent 9604f2a commit f080e47

File tree

3 files changed

+234
-175
lines changed

3 files changed

+234
-175
lines changed

src/distributions/uniform.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ use crate::distributions::utils::Float;
120120
mod uniform_float;
121121
mod uniform_int;
122122
mod uniform_other;
123+
#[cfg(feature = "simd_support")] mod uniform_simd;
123124

124125
pub use uniform_float::UniformFloat;
125126
pub use uniform_int::UniformInt;

src/distributions/uniform/uniform_int.rs

Lines changed: 4 additions & 175 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
use super::{SampleBorrow, SampleUniform, UniformSampler};
1010
use crate::distributions::utils::WideningMultiply;
1111
use crate::Rng;
12-
#[cfg(feature = "simd_support")] use packed_simd::*;
1312
#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize};
1413

1514
/// The back-end implementing [`UniformSampler`] for integer types.
@@ -49,9 +48,10 @@ use crate::Rng;
4948
#[derive(Clone, Copy, Debug, PartialEq)]
5049
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
5150
pub struct UniformInt<X> {
52-
low: X,
53-
range: X,
54-
z: X, // either ints_to_reject or zone depending on implementation
51+
// HACK: fields are pub(crate)
52+
pub(crate) low: X,
53+
pub(crate) range: X,
54+
pub(crate) z: X, // either ints_to_reject or zone depending on implementation
5555
}
5656

5757
macro_rules! uniform_int_impl {
@@ -202,151 +202,6 @@ uniform_int_impl! { u64, u64, u64 }
202202
uniform_int_impl! { usize, usize, usize }
203203
uniform_int_impl! { u128, u128, u128 }
204204

205-
#[cfg(feature = "simd_support")]
206-
macro_rules! uniform_simd_int_impl {
207-
($ty:ident, $unsigned:ident, $u_scalar:ident) => {
208-
// The "pick the largest zone that can fit in an `u32`" optimization
209-
// is less useful here. Multiple lanes complicate things, we don't
210-
// know the PRNG's minimal output size, and casting to a larger vector
211-
// is generally a bad idea for SIMD performance. The user can still
212-
// implement it manually.
213-
214-
// TODO: look into `Uniform::<u32x4>::new(0u32, 100)` functionality
215-
// perhaps `impl SampleUniform for $u_scalar`?
216-
impl SampleUniform for $ty {
217-
type Sampler = UniformInt<$ty>;
218-
}
219-
220-
impl UniformSampler for UniformInt<$ty> {
221-
type X = $ty;
222-
223-
#[inline] // if the range is constant, this helps LLVM to do the
224-
// calculations at compile-time.
225-
fn new<B1, B2>(low_b: B1, high_b: B2) -> Self
226-
where B1: SampleBorrow<Self::X> + Sized,
227-
B2: SampleBorrow<Self::X> + Sized
228-
{
229-
let low = *low_b.borrow();
230-
let high = *high_b.borrow();
231-
assert!(low.lt(high).all(), "Uniform::new called with `low >= high`");
232-
UniformSampler::new_inclusive(low, high - 1)
233-
}
234-
235-
#[inline] // if the range is constant, this helps LLVM to do the
236-
// calculations at compile-time.
237-
fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Self
238-
where B1: SampleBorrow<Self::X> + Sized,
239-
B2: SampleBorrow<Self::X> + Sized
240-
{
241-
let low = *low_b.borrow();
242-
let high = *high_b.borrow();
243-
assert!(low.le(high).all(),
244-
"Uniform::new_inclusive called with `low > high`");
245-
let unsigned_max = ::core::$u_scalar::MAX;
246-
247-
// NOTE: these may need to be replaced with explicitly
248-
// wrapping operations if `packed_simd` changes
249-
let range: $unsigned = ((high - low) + 1).cast();
250-
// `% 0` will panic at runtime.
251-
let not_full_range = range.gt($unsigned::splat(0));
252-
// replacing 0 with `unsigned_max` allows a faster `select`
253-
// with bitwise OR
254-
let modulo = not_full_range.select(range, $unsigned::splat(unsigned_max));
255-
// wrapping addition
256-
let ints_to_reject = (unsigned_max - range + 1) % modulo;
257-
// When `range` is 0, `lo` of `v.wmul(range)` will always be
258-
// zero which means only one sample is needed.
259-
let zone = unsigned_max - ints_to_reject;
260-
261-
UniformInt {
262-
low,
263-
// These are really $unsigned values, but store as $ty:
264-
range: range.cast(),
265-
z: zone.cast(),
266-
}
267-
}
268-
269-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
270-
let range: $unsigned = self.range.cast();
271-
let zone: $unsigned = self.z.cast();
272-
273-
// This might seem very slow, generating a whole new
274-
// SIMD vector for every sample rejection. For most uses
275-
// though, the chance of rejection is small and provides good
276-
// general performance. With multiple lanes, that chance is
277-
// multiplied. To mitigate this, we replace only the lanes of
278-
// the vector which fail, iteratively reducing the chance of
279-
// rejection. The replacement method does however add a little
280-
// overhead. Benchmarking or calculating probabilities might
281-
// reveal contexts where this replacement method is slower.
282-
let mut v: $unsigned = rng.gen();
283-
loop {
284-
let (hi, lo) = v.wmul(range);
285-
let mask = lo.le(zone);
286-
if mask.all() {
287-
let hi: $ty = hi.cast();
288-
// wrapping addition
289-
let result = self.low + hi;
290-
// `select` here compiles to a blend operation
291-
// When `range.eq(0).none()` the compare and blend
292-
// operations are avoided.
293-
let v: $ty = v.cast();
294-
return range.gt($unsigned::splat(0)).select(result, v);
295-
}
296-
// Replace only the failing lanes
297-
v = mask.select(v, rng.gen());
298-
}
299-
}
300-
}
301-
};
302-
303-
// bulk implementation
304-
($(($unsigned:ident, $signed:ident),)+ $u_scalar:ident) => {
305-
$(
306-
uniform_simd_int_impl!($unsigned, $unsigned, $u_scalar);
307-
uniform_simd_int_impl!($signed, $unsigned, $u_scalar);
308-
)+
309-
};
310-
}
311-
312-
#[cfg(feature = "simd_support")]
313-
uniform_simd_int_impl! {
314-
(u64x2, i64x2),
315-
(u64x4, i64x4),
316-
(u64x8, i64x8),
317-
u64
318-
}
319-
320-
#[cfg(feature = "simd_support")]
321-
uniform_simd_int_impl! {
322-
(u32x2, i32x2),
323-
(u32x4, i32x4),
324-
(u32x8, i32x8),
325-
(u32x16, i32x16),
326-
u32
327-
}
328-
329-
#[cfg(feature = "simd_support")]
330-
uniform_simd_int_impl! {
331-
(u16x2, i16x2),
332-
(u16x4, i16x4),
333-
(u16x8, i16x8),
334-
(u16x16, i16x16),
335-
(u16x32, i16x32),
336-
u16
337-
}
338-
339-
#[cfg(feature = "simd_support")]
340-
uniform_simd_int_impl! {
341-
(u8x2, i8x2),
342-
(u8x4, i8x4),
343-
(u8x8, i8x8),
344-
(u8x16, i8x16),
345-
(u8x32, i8x32),
346-
(u8x64, i8x64),
347-
u8
348-
}
349-
350205
#[cfg(test)]
351206
mod tests {
352207
use super::*;
@@ -441,34 +296,8 @@ mod tests {
441296
|x, y| x < y
442297
);)*
443298
}};
444-
445-
// simd bulk
446-
($($ty:ident),* => $scalar:ident) => {{
447-
$(t!(
448-
$ty,
449-
[
450-
($ty::splat(0), $ty::splat(10)),
451-
($ty::splat(10), $ty::splat(127)),
452-
($ty::splat($scalar::MIN), $ty::splat($scalar::MAX)),
453-
],
454-
|x: $ty, y| x.le(y).all(),
455-
|x: $ty, y| x.lt(y).all()
456-
);)*
457-
}};
458299
}
459300
t!(i8, i16, i32, i64, isize, u8, u16, u32, u64, usize, i128, u128);
460-
461-
#[cfg(feature = "simd_support")]
462-
{
463-
t!(u8x2, u8x4, u8x8, u8x16, u8x32, u8x64 => u8);
464-
t!(i8x2, i8x4, i8x8, i8x16, i8x32, i8x64 => i8);
465-
t!(u16x2, u16x4, u16x8, u16x16, u16x32 => u16);
466-
t!(i16x2, i16x4, i16x8, i16x16, i16x32 => i16);
467-
t!(u32x2, u32x4, u32x8, u32x16 => u32);
468-
t!(i32x2, i32x4, i32x8, i32x16 => i32);
469-
t!(u64x2, u64x4, u64x8 => u64);
470-
t!(i64x2, i64x4, i64x8 => i64);
471-
}
472301
}
473302

474303
#[test]

0 commit comments

Comments
 (0)