Skip to content

Commit 4f857a3

Browse files
committed
Fix zero handling in AxisChunksIter/Mut
Now, chunk size of zero and axis length of zero are handled correctly.
1 parent a0130ad commit 4f857a3

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

src/impl_methods.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,7 +1005,7 @@ where
10051005
/// The last view may have less elements if `size` does not divide
10061006
/// the axis' dimension.
10071007
///
1008-
/// **Panics** if `axis` is out of bounds.
1008+
/// **Panics** if `axis` is out of bounds or if `size` is zero.
10091009
///
10101010
/// ```
10111011
/// use ndarray::Array;
@@ -1036,7 +1036,7 @@ where
10361036
///
10371037
/// Iterator element is `ArrayViewMut<A, D>`
10381038
///
1039-
/// **Panics** if `axis` is out of bounds.
1039+
/// **Panics** if `axis` is out of bounds or if `size` is zero.
10401040
pub fn axis_chunks_iter_mut(&mut self, axis: Axis, size: usize) -> AxisChunksIterMut<A, D>
10411041
where
10421042
S: DataMut,

src/iterators/mod.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,21 +1206,28 @@ clone_bounds!(
12061206
///
12071207
/// Returns an axis iterator with the correct stride to move between chunks,
12081208
/// the number of chunks, and the shape of the last chunk.
1209+
///
1210+
/// **Panics** if `size == 0`.
12091211
fn chunk_iter_parts<A, D: Dimension>(
12101212
v: ArrayView<A, D>,
12111213
axis: Axis,
12121214
size: usize,
12131215
) -> (AxisIterCore<A, D>, usize, D) {
1216+
assert_ne!(size, 0, "Chunk size must be nonzero.");
12141217
let axis_len = v.len_of(axis);
1215-
let size = if size > axis_len { axis_len } else { size };
12161218
let n_whole_chunks = axis_len / size;
12171219
let chunk_remainder = axis_len % size;
12181220
let iter_len = if chunk_remainder == 0 {
12191221
n_whole_chunks
12201222
} else {
12211223
n_whole_chunks + 1
12221224
};
1223-
let stride = v.stride_of(axis) * size as isize;
1225+
let stride = if n_whole_chunks == 0 {
1226+
// This case avoids potential overflow when `size > axis_len`.
1227+
0
1228+
} else {
1229+
v.stride_of(axis) * size as isize
1230+
};
12241231

12251232
let axis = axis.index();
12261233
let mut inner_dim = v.dim.clone();

0 commit comments

Comments
 (0)