Skip to content

Fix axis iterators #669

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,7 @@ where
/// The last view may have less elements if `size` does not divide
/// the axis' dimension.
///
/// **Panics** if `axis` is out of bounds.
/// **Panics** if `axis` is out of bounds or if `size` is zero.
///
/// ```
/// use ndarray::Array;
Expand Down Expand Up @@ -1036,7 +1036,7 @@ where
///
/// Iterator element is `ArrayViewMut<A, D>`
///
/// **Panics** if `axis` is out of bounds.
/// **Panics** if `axis` is out of bounds or if `size` is zero.
pub fn axis_chunks_iter_mut(&mut self, axis: Axis, size: usize) -> AxisChunksIterMut<'_, A, D>
where
S: DataMut,
Expand Down
131 changes: 90 additions & 41 deletions src/iterators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -738,11 +738,19 @@ where

#[derive(Debug)]
pub struct AxisIterCore<A, D> {
/// Index along the axis of the value of `.next()`, relative to the start
/// of the axis.
index: Ix,
len: Ix,
/// (Exclusive) upper bound on `index`. Initially, this is equal to the
/// length of the axis.
end: Ix,
/// Stride along the axis (offset between consecutive pointers).
stride: Ixs,
/// Shape of the iterator's items.
inner_dim: D,
/// Strides of the iterator's items.
inner_strides: D,
/// Pointer corresponding to `index == 0`.
ptr: *mut A,
}

Expand All @@ -751,7 +759,7 @@ clone_bounds!(
AxisIterCore[A, D] {
@copy {
index,
len,
end,
stride,
ptr,
}
Expand All @@ -767,54 +775,53 @@ impl<A, D: Dimension> AxisIterCore<A, D> {
Di: RemoveAxis<Smaller = D>,
S: Data<Elem = A>,
{
let shape = v.shape()[axis.index()];
let stride = v.strides()[axis.index()];
AxisIterCore {
index: 0,
len: shape,
stride,
end: v.len_of(axis),
stride: v.stride_of(axis),
inner_dim: v.dim.remove_axis(axis),
inner_strides: v.strides.remove_axis(axis),
ptr: v.ptr,
}
}

#[inline]
unsafe fn offset(&self, index: usize) -> *mut A {
debug_assert!(
index <= self.len,
"index={}, len={}, stride={}",
index < self.end,
"index={}, end={}, stride={}",
index,
self.len,
self.end,
self.stride
);
self.ptr.offset(index as isize * self.stride)
}

/// Split the iterator at index, yielding two disjoint iterators.
/// Splits the iterator at `index`, yielding two disjoint iterators.
///
/// **Panics** if `index` is strictly greater than the iterator's length
/// `index` is relative to the current state of the iterator (which is not
/// necessarily the start of the axis).
///
/// **Panics** if `index` is strictly greater than the iterator's remaining
/// length.
fn split_at(self, index: usize) -> (Self, Self) {
assert!(index <= self.len);
let right_ptr = if index != self.len {
unsafe { self.offset(index) }
} else {
self.ptr
};
assert!(index <= self.len());
let mid = self.index + index;
let left = AxisIterCore {
index: 0,
len: index,
index: self.index,
end: mid,
stride: self.stride,
inner_dim: self.inner_dim.clone(),
inner_strides: self.inner_strides.clone(),
ptr: self.ptr,
};
let right = AxisIterCore {
index: 0,
len: self.len - index,
index: mid,
end: self.end,
stride: self.stride,
inner_dim: self.inner_dim,
inner_strides: self.inner_strides,
ptr: right_ptr,
ptr: self.ptr,
};
(left, right)
}
Expand All @@ -827,7 +834,7 @@ where
type Item = *mut A;

fn next(&mut self) -> Option<Self::Item> {
if self.index >= self.len {
if self.index >= self.end {
None
} else {
let ptr = unsafe { self.offset(self.index) };
Expand All @@ -837,7 +844,7 @@ where
}

fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.len - self.index;
let len = self.len();
(len, Some(len))
}
}
Expand All @@ -847,16 +854,25 @@ where
D: Dimension,
{
fn next_back(&mut self) -> Option<Self::Item> {
if self.index >= self.len {
if self.index >= self.end {
None
} else {
self.len -= 1;
let ptr = unsafe { self.offset(self.len) };
let ptr = unsafe { self.offset(self.end - 1) };
self.end -= 1;
Some(ptr)
}
}
}

impl<A, D> ExactSizeIterator for AxisIterCore<A, D>
where
D: Dimension,
{
fn len(&self) -> usize {
self.end - self.index
}
}

/// An iterator that traverses over an axis and
/// and yields each subview.
///
Expand Down Expand Up @@ -899,9 +915,13 @@ impl<'a, A, D: Dimension> AxisIter<'a, A, D> {
}
}

/// Split the iterator at index, yielding two disjoint iterators.
/// Splits the iterator at `index`, yielding two disjoint iterators.
///
/// **Panics** if `index` is strictly greater than the iterator's length
/// `index` is relative to the current state of the iterator (which is not
/// necessarily the start of the axis).
///
/// **Panics** if `index` is strictly greater than the iterator's remaining
/// length.
pub fn split_at(self, index: usize) -> (Self, Self) {
let (left, right) = self.iter.split_at(index);
(
Expand Down Expand Up @@ -946,7 +966,7 @@ where
D: Dimension,
{
fn len(&self) -> usize {
self.size_hint().0
self.iter.len()
}
}

Expand Down Expand Up @@ -981,9 +1001,13 @@ impl<'a, A, D: Dimension> AxisIterMut<'a, A, D> {
}
}

/// Split the iterator at index, yielding two disjoint iterators.
/// Splits the iterator at `index`, yielding two disjoint iterators.
///
/// **Panics** if `index` is strictly greater than the iterator's length
/// `index` is relative to the current state of the iterator (which is not
/// necessarily the start of the axis).
///
/// **Panics** if `index` is strictly greater than the iterator's remaining
/// length.
pub fn split_at(self, index: usize) -> (Self, Self) {
let (left, right) = self.iter.split_at(index);
(
Expand Down Expand Up @@ -1028,7 +1052,7 @@ where
D: Dimension,
{
fn len(&self) -> usize {
self.size_hint().0
self.iter.len()
}
}

Expand All @@ -1048,7 +1072,16 @@ impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> {
}
#[doc(hidden)]
fn as_ptr(&self) -> Self::Ptr {
self.iter.ptr
if self.len() > 0 {
// `self.iter.index` is guaranteed to be in-bounds if any of the
// iterator remains (i.e. if `self.len() > 0`).
unsafe { self.iter.offset(self.iter.index) }
} else {
// In this case, `self.iter.index` may be past the end, so we must
// not call `.offset()`. It's okay to return a dangling pointer
// because it will never be used in the length 0 case.
std::ptr::NonNull::dangling().as_ptr()
}
}

fn contiguous_stride(&self) -> isize {
Expand All @@ -1065,7 +1098,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> {
}
#[doc(hidden)]
unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
self.iter.ptr.offset(self.iter.stride * i[0] as isize)
self.iter.offset(self.iter.index + i[0])
}

#[doc(hidden)]
Expand Down Expand Up @@ -1096,7 +1129,16 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
}
#[doc(hidden)]
fn as_ptr(&self) -> Self::Ptr {
self.iter.ptr
if self.len() > 0 {
// `self.iter.index` is guaranteed to be in-bounds if any of the
// iterator remains (i.e. if `self.len() > 0`).
unsafe { self.iter.offset(self.iter.index) }
} else {
// In this case, `self.iter.index` may be past the end, so we must
// not call `.offset()`. It's okay to return a dangling pointer
// because it will never be used in the length 0 case.
std::ptr::NonNull::dangling().as_ptr()
}
}

fn contiguous_stride(&self) -> isize {
Expand All @@ -1113,7 +1155,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
}
#[doc(hidden)]
unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
self.iter.ptr.offset(self.iter.stride * i[0] as isize)
self.iter.offset(self.iter.index + i[0])
}

#[doc(hidden)]
Expand Down Expand Up @@ -1164,21 +1206,28 @@ clone_bounds!(
///
/// Returns an axis iterator with the correct stride to move between chunks,
/// the number of chunks, and the shape of the last chunk.
///
/// **Panics** if `size == 0`.
fn chunk_iter_parts<A, D: Dimension>(
v: ArrayView<'_, A, D>,
axis: Axis,
size: usize,
) -> (AxisIterCore<A, D>, usize, D) {
assert_ne!(size, 0, "Chunk size must be nonzero.");
let axis_len = v.len_of(axis);
let size = if size > axis_len { axis_len } else { size };
let n_whole_chunks = axis_len / size;
let chunk_remainder = axis_len % size;
let iter_len = if chunk_remainder == 0 {
n_whole_chunks
} else {
n_whole_chunks + 1
};
let stride = v.stride_of(axis) * size as isize;
let stride = if n_whole_chunks == 0 {
// This case avoids potential overflow when `size > axis_len`.
0
} else {
v.stride_of(axis) * size as isize
};

let axis = axis.index();
let mut inner_dim = v.dim.clone();
Expand All @@ -1193,7 +1242,7 @@ fn chunk_iter_parts<A, D: Dimension>(

let iter = AxisIterCore {
index: 0,
len: iter_len,
end: iter_len,
stride,
inner_dim,
inner_strides: v.strides,
Expand Down Expand Up @@ -1270,7 +1319,7 @@ macro_rules! chunk_iter_impl {
D: Dimension,
{
fn next_back(&mut self) -> Option<Self::Item> {
let is_uneven = self.iter.len > self.n_whole_chunks;
let is_uneven = self.iter.end > self.n_whole_chunks;
let res = self.iter.next_back();
self.get_subview(res, is_uneven)
}
Expand Down
Loading