|
9 | 9 | use super::{SampleBorrow, SampleUniform, UniformSampler};
|
10 | 10 | use crate::distributions::utils::WideningMultiply;
|
11 | 11 | use crate::Rng;
|
12 |
| -#[cfg(feature = "simd_support")] use packed_simd::*; |
13 | 12 | #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize};
|
14 | 13 |
|
15 | 14 | /// The back-end implementing [`UniformSampler`] for integer types.
|
@@ -49,9 +48,10 @@ use crate::Rng;
|
49 | 48 | #[derive(Clone, Copy, Debug, PartialEq)]
|
50 | 49 | #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
51 | 50 | pub struct UniformInt<X> {
|
52 |
| - low: X, |
53 |
| - range: X, |
54 |
| - z: X, // either ints_to_reject or zone depending on implementation |
| 51 | + // HACK: fields are pub(crate) |
| 52 | + pub(crate) low: X, |
| 53 | + pub(crate) range: X, |
| 54 | + pub(crate) z: X, // either ints_to_reject or zone depending on implementation |
55 | 55 | }
|
56 | 56 |
|
57 | 57 | macro_rules! uniform_int_impl {
|
@@ -202,151 +202,6 @@ uniform_int_impl! { u64, u64, u64 }
|
202 | 202 | uniform_int_impl! { usize, usize, usize }
|
203 | 203 | uniform_int_impl! { u128, u128, u128 }
|
204 | 204 |
|
205 |
| -#[cfg(feature = "simd_support")] |
206 |
| -macro_rules! uniform_simd_int_impl { |
207 |
| - ($ty:ident, $unsigned:ident, $u_scalar:ident) => { |
208 |
| - // The "pick the largest zone that can fit in an `u32`" optimization |
209 |
| - // is less useful here. Multiple lanes complicate things, we don't |
210 |
| - // know the PRNG's minimal output size, and casting to a larger vector |
211 |
| - // is generally a bad idea for SIMD performance. The user can still |
212 |
| - // implement it manually. |
213 |
| - |
214 |
| - // TODO: look into `Uniform::<u32x4>::new(0u32, 100)` functionality |
215 |
| - // perhaps `impl SampleUniform for $u_scalar`? |
216 |
| - impl SampleUniform for $ty { |
217 |
| - type Sampler = UniformInt<$ty>; |
218 |
| - } |
219 |
| - |
220 |
| - impl UniformSampler for UniformInt<$ty> { |
221 |
| - type X = $ty; |
222 |
| - |
223 |
| - #[inline] // if the range is constant, this helps LLVM to do the |
224 |
| - // calculations at compile-time. |
225 |
| - fn new<B1, B2>(low_b: B1, high_b: B2) -> Self |
226 |
| - where B1: SampleBorrow<Self::X> + Sized, |
227 |
| - B2: SampleBorrow<Self::X> + Sized |
228 |
| - { |
229 |
| - let low = *low_b.borrow(); |
230 |
| - let high = *high_b.borrow(); |
231 |
| - assert!(low.lt(high).all(), "Uniform::new called with `low >= high`"); |
232 |
| - UniformSampler::new_inclusive(low, high - 1) |
233 |
| - } |
234 |
| - |
235 |
| - #[inline] // if the range is constant, this helps LLVM to do the |
236 |
| - // calculations at compile-time. |
237 |
| - fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Self |
238 |
| - where B1: SampleBorrow<Self::X> + Sized, |
239 |
| - B2: SampleBorrow<Self::X> + Sized |
240 |
| - { |
241 |
| - let low = *low_b.borrow(); |
242 |
| - let high = *high_b.borrow(); |
243 |
| - assert!(low.le(high).all(), |
244 |
| - "Uniform::new_inclusive called with `low > high`"); |
245 |
| - let unsigned_max = ::core::$u_scalar::MAX; |
246 |
| - |
247 |
| - // NOTE: these may need to be replaced with explicitly |
248 |
| - // wrapping operations if `packed_simd` changes |
249 |
| - let range: $unsigned = ((high - low) + 1).cast(); |
250 |
| - // `% 0` will panic at runtime. |
251 |
| - let not_full_range = range.gt($unsigned::splat(0)); |
252 |
| - // replacing 0 with `unsigned_max` allows a faster `select` |
253 |
| - // with bitwise OR |
254 |
| - let modulo = not_full_range.select(range, $unsigned::splat(unsigned_max)); |
255 |
| - // wrapping addition |
256 |
| - let ints_to_reject = (unsigned_max - range + 1) % modulo; |
257 |
| - // When `range` is 0, `lo` of `v.wmul(range)` will always be |
258 |
| - // zero which means only one sample is needed. |
259 |
| - let zone = unsigned_max - ints_to_reject; |
260 |
| - |
261 |
| - UniformInt { |
262 |
| - low, |
263 |
| - // These are really $unsigned values, but store as $ty: |
264 |
| - range: range.cast(), |
265 |
| - z: zone.cast(), |
266 |
| - } |
267 |
| - } |
268 |
| - |
269 |
| - fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X { |
270 |
| - let range: $unsigned = self.range.cast(); |
271 |
| - let zone: $unsigned = self.z.cast(); |
272 |
| - |
273 |
| - // This might seem very slow, generating a whole new |
274 |
| - // SIMD vector for every sample rejection. For most uses |
275 |
| - // though, the chance of rejection is small and provides good |
276 |
| - // general performance. With multiple lanes, that chance is |
277 |
| - // multiplied. To mitigate this, we replace only the lanes of |
278 |
| - // the vector which fail, iteratively reducing the chance of |
279 |
| - // rejection. The replacement method does however add a little |
280 |
| - // overhead. Benchmarking or calculating probabilities might |
281 |
| - // reveal contexts where this replacement method is slower. |
282 |
| - let mut v: $unsigned = rng.gen(); |
283 |
| - loop { |
284 |
| - let (hi, lo) = v.wmul(range); |
285 |
| - let mask = lo.le(zone); |
286 |
| - if mask.all() { |
287 |
| - let hi: $ty = hi.cast(); |
288 |
| - // wrapping addition |
289 |
| - let result = self.low + hi; |
290 |
| - // `select` here compiles to a blend operation |
291 |
| - // When `range.eq(0).none()` the compare and blend |
292 |
| - // operations are avoided. |
293 |
| - let v: $ty = v.cast(); |
294 |
| - return range.gt($unsigned::splat(0)).select(result, v); |
295 |
| - } |
296 |
| - // Replace only the failing lanes |
297 |
| - v = mask.select(v, rng.gen()); |
298 |
| - } |
299 |
| - } |
300 |
| - } |
301 |
| - }; |
302 |
| - |
303 |
| - // bulk implementation |
304 |
| - ($(($unsigned:ident, $signed:ident),)+ $u_scalar:ident) => { |
305 |
| - $( |
306 |
| - uniform_simd_int_impl!($unsigned, $unsigned, $u_scalar); |
307 |
| - uniform_simd_int_impl!($signed, $unsigned, $u_scalar); |
308 |
| - )+ |
309 |
| - }; |
310 |
| -} |
311 |
| - |
312 |
| -#[cfg(feature = "simd_support")] |
313 |
| -uniform_simd_int_impl! { |
314 |
| - (u64x2, i64x2), |
315 |
| - (u64x4, i64x4), |
316 |
| - (u64x8, i64x8), |
317 |
| - u64 |
318 |
| -} |
319 |
| - |
320 |
| -#[cfg(feature = "simd_support")] |
321 |
| -uniform_simd_int_impl! { |
322 |
| - (u32x2, i32x2), |
323 |
| - (u32x4, i32x4), |
324 |
| - (u32x8, i32x8), |
325 |
| - (u32x16, i32x16), |
326 |
| - u32 |
327 |
| -} |
328 |
| - |
329 |
| -#[cfg(feature = "simd_support")] |
330 |
| -uniform_simd_int_impl! { |
331 |
| - (u16x2, i16x2), |
332 |
| - (u16x4, i16x4), |
333 |
| - (u16x8, i16x8), |
334 |
| - (u16x16, i16x16), |
335 |
| - (u16x32, i16x32), |
336 |
| - u16 |
337 |
| -} |
338 |
| - |
339 |
| -#[cfg(feature = "simd_support")] |
340 |
| -uniform_simd_int_impl! { |
341 |
| - (u8x2, i8x2), |
342 |
| - (u8x4, i8x4), |
343 |
| - (u8x8, i8x8), |
344 |
| - (u8x16, i8x16), |
345 |
| - (u8x32, i8x32), |
346 |
| - (u8x64, i8x64), |
347 |
| - u8 |
348 |
| -} |
349 |
| - |
350 | 205 | #[cfg(test)]
|
351 | 206 | mod tests {
|
352 | 207 | use super::*;
|
@@ -441,34 +296,8 @@ mod tests {
|
441 | 296 | |x, y| x < y
|
442 | 297 | );)*
|
443 | 298 | }};
|
444 |
| - |
445 |
| - // simd bulk |
446 |
| - ($($ty:ident),* => $scalar:ident) => {{ |
447 |
| - $(t!( |
448 |
| - $ty, |
449 |
| - [ |
450 |
| - ($ty::splat(0), $ty::splat(10)), |
451 |
| - ($ty::splat(10), $ty::splat(127)), |
452 |
| - ($ty::splat($scalar::MIN), $ty::splat($scalar::MAX)), |
453 |
| - ], |
454 |
| - |x: $ty, y| x.le(y).all(), |
455 |
| - |x: $ty, y| x.lt(y).all() |
456 |
| - );)* |
457 |
| - }}; |
458 | 299 | }
|
459 | 300 | t!(i8, i16, i32, i64, isize, u8, u16, u32, u64, usize, i128, u128);
|
460 |
| - |
461 |
| - #[cfg(feature = "simd_support")] |
462 |
| - { |
463 |
| - t!(u8x2, u8x4, u8x8, u8x16, u8x32, u8x64 => u8); |
464 |
| - t!(i8x2, i8x4, i8x8, i8x16, i8x32, i8x64 => i8); |
465 |
| - t!(u16x2, u16x4, u16x8, u16x16, u16x32 => u16); |
466 |
| - t!(i16x2, i16x4, i16x8, i16x16, i16x32 => i16); |
467 |
| - t!(u32x2, u32x4, u32x8, u32x16 => u32); |
468 |
| - t!(i32x2, i32x4, i32x8, i32x16 => i32); |
469 |
| - t!(u64x2, u64x4, u64x8 => u64); |
470 |
| - t!(i64x2, i64x4, i64x8 => i64); |
471 |
| - } |
472 | 301 | }
|
473 | 302 |
|
474 | 303 | #[test]
|
|
0 commit comments