Skip to content

Commit c94a1b3

Browse files
committed
Fix size_hint for partially consumed QueryIter
Instead of returning the total count of elements in the `QueryIter` in `size_hint`, we return the count of remaining elements in it. This Fixes #5149. This is also true of `QueryCombinationIter`. - #5149 - #5148
1 parent 5b5013d commit c94a1b3

File tree

2 files changed

+80
-34
lines changed

2 files changed

+80
-34
lines changed

crates/bevy_ecs/src/query/iter.rs

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,7 @@ where
6262
}
6363

6464
fn size_hint(&self) -> (usize, Option<usize>) {
65-
let max_size = self
66-
.query_state
67-
.matched_archetype_ids
68-
.iter()
69-
.map(|id| self.archetypes[*id].len())
70-
.sum();
71-
65+
let max_size = self.cursor.remaining(self.tables, self.archetypes);
7266
let archetype_query = F::Fetch::IS_ARCHETYPAL && QF::IS_ARCHETYPAL;
7367
let min_size = if archetype_query { max_size } else { 0 };
7468
(min_size, Some(max_size))
@@ -264,11 +258,16 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery, const K: usize> QueryCombinationIter<
264258
return None;
265259
}
266260

267-
// first, iterate from last to first until next item is found
261+
// TODO: can speed up the following code using `cursor.remaining()` instead of `next_item.is_none()`
262+
// when Q::Fetch::IS_ARCHETYPAL && F::Fetch::IS_ARCHETYPAL
263+
//
264+
// let `i` be the index of `c`, the last cursor in `self.cursors` that
265+
// returns `K-i` or more elements.
266+
// Make cursor in index `j` for all `j` in `[i, K)` a copy of `c` advanced `j-i+1` times.
267+
// If no such `c` exists, return `None`
268268
'outer: for i in (0..K).rev() {
269269
match self.cursors[i].next(self.tables, self.archetypes, self.query_state) {
270270
Some(_) => {
271-
// walk forward up to last element, propagating cursor state forward
272271
for j in (i + 1)..K {
273272
self.cursors[j] = self.cursors[j - 1].clone();
274273
match self.cursors[j].next(self.tables, self.archetypes, self.query_state) {
@@ -329,31 +328,32 @@ where
329328
}
330329

331330
fn size_hint(&self) -> (usize, Option<usize>) {
332-
if K == 0 {
333-
return (0, Some(0));
331+
// binomial coefficient: (n ; k) = n! / k!(n-k)! = (n*n-1*...*n-k+1) / k!
332+
// See https://en.wikipedia.org/wiki/Binomial_coefficient
333+
// See https://blog.plover.com/math/choose.html for implementation
334+
// It was chosen to reduce overflow potential.
335+
fn choose(n: usize, k: usize) -> Option<usize> {
336+
if k > n {
337+
return Some(0);
338+
}
339+
let ks = 1..=k;
340+
let ns = (n + 1 - k..=n).rev();
341+
ks.zip(ns)
342+
.try_fold(1_usize, |acc, (k, n)| Some(acc.checked_mul(n)? / k))
334343
}
335-
336-
let max_size: usize = self
337-
.query_state
338-
.matched_archetype_ids
344+
// sum_i=0..k choose(cursors[i].remaining, k-i)
345+
let max_combinations = self
346+
.cursors
339347
.iter()
340-
.map(|id| self.archetypes[*id].len())
341-
.sum();
342-
343-
if max_size < K {
344-
return (0, Some(0));
345-
}
346-
347-
// n! / k!(n-k)! = (n*n-1*...*n-k+1) / k!
348-
let max_combinations = (0..K)
349-
.try_fold(1usize, |n, i| n.checked_mul(max_size - i))
350-
.map(|n| {
351-
let k_factorial: usize = (1..=K).product();
352-
n / k_factorial
348+
.enumerate()
349+
.try_fold(0, |acc, (i, cursor)| {
350+
let n = cursor.remaining(self.tables, self.archetypes);
351+
Some(acc + choose(n, K - i)?)
353352
});
354353

355354
let archetype_query = F::Fetch::IS_ARCHETYPAL && Q::Fetch::IS_ARCHETYPAL;
356-
let min_combinations = if archetype_query { max_size } else { 0 };
355+
let known_max = max_combinations.unwrap_or(usize::MAX);
356+
let min_combinations = if archetype_query { known_max } else { 0 };
357357
(min_combinations, max_combinations)
358358
}
359359
}
@@ -364,11 +364,7 @@ where
364364
F: WorldQuery + ArchetypeFilter,
365365
{
366366
fn len(&self) -> usize {
367-
self.query_state
368-
.matched_archetype_ids
369-
.iter()
370-
.map(|id| self.archetypes[*id].len())
371-
.sum()
367+
self.size_hint().0
372368
}
373369
}
374370

@@ -473,6 +469,18 @@ where
473469
}
474470
}
475471

472+
/// How many values will this cursor return?
473+
fn remaining(&self, tables: &'w Tables, archetypes: &'w Archetypes) -> usize {
474+
let remaining_matched: usize = if Self::IS_DENSE {
475+
let ids = self.table_id_iter.clone();
476+
ids.map(|id| tables[*id].len()).sum()
477+
} else {
478+
let ids = self.archetype_id_iter.clone();
479+
ids.map(|id| archetypes[*id].len()).sum()
480+
};
481+
remaining_matched + self.current_len - self.current_index
482+
}
483+
476484
// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
477485
// QueryIterationCursor, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual
478486
/// # Safety

crates/bevy_ecs/src/query/mod.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,18 +144,30 @@ mod tests {
144144
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
145145
assert_eq!(values.iter(&world).len(), n);
146146
assert_eq!(values.iter(&world).count(), n);
147+
let mut iterator = values.iter(&world);
148+
let _ = iterator.next();
149+
assert_eq!(iterator.len(), n - 1);
150+
147151
let mut values = world.query_filtered::<&A, Or<(With<B>, Without<C>)>>();
148152
let n = 7;
149153
assert_eq!(values.iter(&world).size_hint().0, n);
150154
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
151155
assert_eq!(values.iter(&world).len(), n);
152156
assert_eq!(values.iter(&world).count(), n);
157+
let mut iterator = values.iter(&world);
158+
let _ = iterator.next();
159+
assert_eq!(iterator.len(), n - 1);
160+
153161
let mut values = world.query_filtered::<&A, Or<(Without<B>, With<C>)>>();
154162
let n = 8;
155163
assert_eq!(values.iter(&world).size_hint().0, n);
156164
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
157165
assert_eq!(values.iter(&world).len(), n);
158166
assert_eq!(values.iter(&world).count(), n);
167+
let mut iterator = values.iter(&world);
168+
let _ = iterator.next();
169+
assert_eq!(iterator.len(), n - 1);
170+
159171
let mut values = world.query_filtered::<&A, Or<(Without<B>, Without<C>)>>();
160172
let n = 9;
161173
assert_eq!(values.iter(&world).size_hint().0, n);
@@ -169,6 +181,12 @@ mod tests {
169181
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
170182
assert_eq!(values.iter(&world).len(), n);
171183
assert_eq!(values.iter(&world).count(), n);
184+
let mut iterator = values.iter(&world);
185+
let _ = iterator.next();
186+
assert_eq!(iterator.len(), 0);
187+
let _ = iterator.next();
188+
assert_eq!(iterator.len(), 0);
189+
172190
let mut values = world.query_filtered::<&A, Or<(Or<(With<B>, With<C>)>, With<D>)>>();
173191
let n = 6;
174192
assert_eq!(values.iter(&world).size_hint().0, n);
@@ -218,6 +236,18 @@ mod tests {
218236
assert_eq!(a_query.iter_combinations::<128>(w).count(), 0);
219237
assert_eq!(a_query.iter_combinations::<128>(w).size_hint().1, Some(0));
220238

239+
let mut combination = a_query.iter_combinations::<2>(w);
240+
let mut expected = 6;
241+
for _ in 0..6 {
242+
let _ = combination.next();
243+
expected -= 1;
244+
assert_eq!(combination.size_hint().1, Some(expected));
245+
}
246+
247+
let mut combination = a_query.iter_combinations::<4>(w);
248+
let _ = combination.next();
249+
assert_eq!(combination.size_hint().1, Some(0));
250+
221251
let values: Vec<[&A; 2]> = world.query::<&A>().iter_combinations(&world).collect();
222252
assert_eq!(
223253
values,
@@ -299,6 +329,10 @@ mod tests {
299329
assert_eq!(a_with_b.iter_combinations::<128>(w).count(), 0);
300330
assert_eq!(a_with_b.iter_combinations::<128>(w).size_hint().1, Some(0));
301331

332+
let mut combination = a_with_b.iter_combinations::<1>(w);
333+
_ = combination.next();
334+
assert_eq!(combination.size_hint().1, Some(0));
335+
302336
let mut a_wout_b = world.query_filtered::<&A, Without<B>>();
303337
let w = &world;
304338
assert_eq!(a_wout_b.iter_combinations::<0>(w).count(), 0);
@@ -316,6 +350,10 @@ mod tests {
316350
assert_eq!(a_wout_b.iter_combinations::<128>(w).count(), 0);
317351
assert_eq!(a_wout_b.iter_combinations::<128>(w).size_hint().1, Some(0));
318352

353+
let mut combination = a_wout_b.iter_combinations::<2>(w);
354+
_ = combination.next();
355+
assert_eq!(combination.size_hint().1, Some(2));
356+
319357
let values: HashSet<[&A; 2]> = a_wout_b.iter_combinations(&world).collect();
320358
assert_eq!(
321359
values,

0 commit comments

Comments
 (0)