|
12 | 12 |
|
13 | 13 | use super::Rng;
|
14 | 14 |
|
15 |
| -// This crate is only enabled when either std or alloc is available. |
16 |
| -// BTreeMap is not as fast in tests, but better than nothing. |
17 |
| -#[cfg(feature="std")] use std::collections::HashMap; |
18 |
| -#[cfg(not(feature="std"))] use alloc::btree_map::BTreeMap; |
19 |
| - |
20 | 15 | #[cfg(not(feature="std"))] use alloc::Vec;
|
21 | 16 |
|
22 | 17 | /// Randomly sample `amount` elements from a finite iterator.
|
@@ -139,87 +134,13 @@ pub fn sample_indices<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize
|
139 | 134 | panic!("`amount` must be less than or equal to `slice.len()`");
|
140 | 135 | }
|
141 | 136 |
|
142 |
| - // We are going to have to allocate at least `amount` for the output no matter what. However, |
143 |
| - // if we use the `cached` version we will have to allocate `amount` as a HashMap as well since |
144 |
| - // it inserts an element for every loop. |
145 |
| - // |
146 |
| - // Therefore, if `amount >= length / 2` then inplace will be both faster and use less memory. |
147 |
| - // In fact, benchmarks show the inplace version is faster for length up to about 20 times |
148 |
| - // faster than amount. |
149 |
| - // |
150 |
| - // TODO: there is probably even more fine-tuning that can be done here since |
151 |
| - // `HashMap::with_capacity(amount)` probably allocates more than `amount` in practice, |
152 |
| - // and a trade off could probably be made between memory/cpu, since hashmap operations |
153 |
| - // are slower than array index swapping. |
154 |
| - if amount >= length / 20 { |
155 |
| - sample_indices_inplace(rng, length, amount) |
156 |
| - } else { |
157 |
| - sample_indices_cache(rng, length, amount) |
158 |
| - } |
159 |
| -} |
160 |
| - |
161 |
| -/// Sample an amount of indices using an inplace partial fisher yates method. |
162 |
| -/// |
163 |
| -/// This allocates the entire `length` of indices and randomizes only the first `amount`. |
164 |
| -/// It then truncates to `amount` and returns. |
165 |
| -/// |
166 |
| -/// This is better than using a `HashMap` "cache" when `amount >= length / 2` |
167 |
| -/// since it does not require allocating an extra cache and is much faster. |
168 |
| -fn sample_indices_inplace<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize> |
169 |
| - where R: Rng + ?Sized, |
170 |
| -{ |
171 |
| - debug_assert!(amount <= length); |
172 |
| - let mut indices: Vec<usize> = Vec::with_capacity(length); |
173 |
| - indices.extend(0..length); |
174 |
| - for i in 0..amount { |
175 |
| - let j: usize = rng.gen_range(i, length); |
176 |
| - indices.swap(i, j); |
177 |
| - } |
178 |
| - indices.truncate(amount); |
179 |
| - debug_assert_eq!(indices.len(), amount); |
180 |
| - indices |
181 |
| -} |
182 |
| - |
183 |
| - |
184 |
| -/// This method performs a partial fisher-yates on a range of indices using a |
185 |
| -/// `HashMap` as a cache to record potential collisions. |
186 |
| -/// |
187 |
| -/// The cache avoids allocating the entire `length` of values. This is especially useful when |
188 |
| -/// `amount <<< length`, i.e. select 3 non-repeating from `1_000_000` |
189 |
| -fn sample_indices_cache<R>( |
190 |
| - rng: &mut R, |
191 |
| - length: usize, |
192 |
| - amount: usize, |
193 |
| -) -> Vec<usize> |
194 |
| - where R: Rng + ?Sized, |
195 |
| -{ |
196 |
| - debug_assert!(amount <= length); |
197 |
| - #[cfg(feature="std")] let mut cache = HashMap::with_capacity(amount); |
198 |
| - #[cfg(not(feature="std"))] let mut cache = BTreeMap::new(); |
199 |
| - let mut out = Vec::with_capacity(amount); |
200 |
| - for i in 0..amount { |
201 |
| - let j: usize = rng.gen_range(i, length); |
202 |
| - |
203 |
| - // equiv: let tmp = slice[i]; |
204 |
| - let tmp = match cache.get(&i) { |
205 |
| - Some(e) => *e, |
206 |
| - None => i, |
207 |
| - }; |
208 |
| - |
209 |
| - // equiv: slice[i] = slice[j]; |
210 |
| - let x = match cache.get(&j) { |
211 |
| - Some(x) => *x, |
212 |
| - None => j, |
213 |
| - }; |
214 |
| - |
215 |
| - // equiv: slice[j] = tmp; |
216 |
| - cache.insert(j, tmp); |
217 |
| - |
218 |
| - // note that in the inplace version, slice[i] is automatically "returned" value |
219 |
| - out.push(x); |
| 137 | + let mut s = Vec::with_capacity(amount); |
| 138 | + for j in length - amount .. length { |
| 139 | + let t = rng.gen_range(0, j + 1); |
| 140 | + let t = if s.contains(&t) { j } else { t }; |
| 141 | + s.push( t ); |
220 | 142 | }
|
221 |
| - debug_assert_eq!(out.len(), amount); |
222 |
| - out |
| 143 | + s |
223 | 144 | }
|
224 | 145 |
|
225 | 146 | #[cfg(test)]
|
@@ -267,13 +188,9 @@ mod test {
|
267 | 188 | let v = sample_slice(&mut r, &[42, 133], 2);
|
268 | 189 | assert!(&v[..] == [42, 133] || v[..] == [133, 42]);
|
269 | 190 |
|
270 |
| - assert_eq!(&sample_indices_inplace(&mut r, 0, 0)[..], [0usize; 0]); |
271 |
| - assert_eq!(&sample_indices_inplace(&mut r, 1, 0)[..], [0usize; 0]); |
272 |
| - assert_eq!(&sample_indices_inplace(&mut r, 1, 1)[..], [0]); |
273 |
| - |
274 |
| - assert_eq!(&sample_indices_cache(&mut r, 0, 0)[..], [0usize; 0]); |
275 |
| - assert_eq!(&sample_indices_cache(&mut r, 1, 0)[..], [0usize; 0]); |
276 |
| - assert_eq!(&sample_indices_cache(&mut r, 1, 1)[..], [0]); |
| 191 | + assert_eq!(&sample_indices(&mut r, 0, 0)[..], [0usize; 0]); |
| 192 | + assert_eq!(&sample_indices(&mut r, 1, 0)[..], [0usize; 0]); |
| 193 | + assert_eq!(&sample_indices(&mut r, 1, 1)[..], [0]); |
277 | 194 |
|
278 | 195 | // Make sure lucky 777's aren't lucky
|
279 | 196 | let slice = &[42, 777];
|
@@ -304,19 +221,11 @@ mod test {
|
304 | 221 | let mut seed = [0u8; 16];
|
305 | 222 | r.fill(&mut seed);
|
306 | 223 |
|
307 |
| - // assert that the two index methods give exactly the same result |
308 |
| - let inplace = sample_indices_inplace( |
309 |
| - &mut xor_rng(seed), length, amount); |
310 |
| - let cache = sample_indices_cache( |
311 |
| - &mut xor_rng(seed), length, amount); |
312 |
| - assert_eq!(inplace, cache); |
313 |
| - |
314 | 224 | // assert the basics work
|
315 | 225 | let regular = sample_indices(
|
316 | 226 | &mut xor_rng(seed), length, amount);
|
317 | 227 | assert_eq!(regular.len(), amount);
|
318 | 228 | assert!(regular.iter().all(|e| *e < length));
|
319 |
| - assert_eq!(regular, inplace); |
320 | 229 |
|
321 | 230 | // also test that sampling the slice works
|
322 | 231 | let vec: Vec<usize> = (0..length).collect();
|
|
0 commit comments