diff --git a/Guidelines.rst b/Guidelines.rst index 4ca9d2fd..cbeff3b7 100644 --- a/Guidelines.rst +++ b/Guidelines.rst @@ -40,3 +40,16 @@ using ``rustfmt`` on the latest stable Rust channel: cargo fmt --all .. _rustfmt: https://github.com/rust-lang-nursery/rustfmt + +Reborrowing +=========== + +Some methods return views of their containers, eg ``CsMatBase::slice_outer`` +returns a ``CsMatViewI``. However, in certain situations, mostly when +implementing iterators, we are calling these kind of methods on a view, and +need to take the lifetime of the view, not the lifetime of ``self`` in the +method call. To deal with this issue, the method should in fact be implemented +on the view type (on ``CsMatViewI`` in the example), with a ``_rbr`` suffix ( +``CsMatViewI::slice_outer_rbr`` in the example), and the implementation on the +base type should simply call the view version (in the example, it should call +``self.view().slice_outer_rbr(range)``). diff --git a/src/lib.rs b/src/lib.rs index 0bdc76fa..70f2c1f2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -79,6 +79,7 @@ pub mod errors; pub mod indexing; pub mod io; pub mod num_kinds; +mod range; mod sparse; pub mod stack; diff --git a/src/range.rs b/src/range.rs new file mode 100644 index 00000000..9e90cb90 --- /dev/null +++ b/src/range.rs @@ -0,0 +1,70 @@ +//! This module abstracts over ranges to allow functions in the crate to take +//! ranges as input in an ergonomic way. It howvers seals this abstractions to +//! leave full control in this crate. + +/// Abstract over `std::ops::{Range,RangeFrom,RangeTo,RangeFull}` +pub trait Range { + fn start(&self) -> usize; + + fn end(&self) -> Option; +} + +impl Range for std::ops::Range { + fn start(&self) -> usize { + self.start + } + + fn end(&self) -> Option { + Some(self.end) + } +} + +impl Range for std::ops::RangeFrom { + fn start(&self) -> usize { + self.start + } + + fn end(&self) -> Option { + None + } +} + +impl Range for std::ops::RangeTo { + fn start(&self) -> usize { + 0 + } + + fn end(&self) -> Option { + Some(self.end) + } +} + +impl Range for std::ops::RangeFull { + fn start(&self) -> usize { + 0 + } + + fn end(&self) -> Option { + None + } +} + +impl Range for std::ops::RangeInclusive { + fn start(&self) -> usize { + *self.start() + } + + fn end(&self) -> Option { + Some(*self.end() + 1) + } +} + +impl Range for std::ops::RangeToInclusive { + fn start(&self) -> usize { + 0 + } + + fn end(&self) -> Option { + Some(self.end + 1) + } +} diff --git a/src/sparse.rs b/src/sparse.rs index c3fb1332..eb6fc3ba 100644 --- a/src/sparse.rs +++ b/src/sparse.rs @@ -418,6 +418,7 @@ pub mod kronecker; pub mod linalg; pub mod permutation; pub mod prod; +pub mod slicing; pub mod smmp; pub mod special_mats; pub mod symmetric; diff --git a/src/sparse/csmat.rs b/src/sparse/csmat.rs index 97dce25a..1bc3dbe6 100644 --- a/src/sparse/csmat.rs +++ b/src/sparse/csmat.rs @@ -667,6 +667,15 @@ impl<'a, N: 'a, I: 'a + SpIndex, Iptr: 'a + SpIndex> /// Get a view into count contiguous outer dimensions, starting from i. /// /// eg this gets the rows from i to i + count in a CSR matrix + /// + /// This function is now deprecated, as using an index and a count is not + /// ergonomic. The replacement, `slice_outer`, leverages the + /// `std::ops::Range` family of types, which is better integrated into the + /// ecosystem. + #[deprecated( + since = "0.10.0", + note = "Please use the `slice_outer` method instead" + )] pub fn middle_outer_views( &self, i: usize, @@ -682,7 +691,7 @@ impl<'a, N: 'a, I: 'a + SpIndex, Iptr: 'a + SpIndex> storage: self.storage, nrows, ncols, - indptr: self.indptr.middle_slice(i, iend), + indptr: self.indptr.middle_slice_rbr(i..iend), indices: &self.indices[data_range.clone()], data: &self.data[data_range], } @@ -1185,7 +1194,7 @@ where } else { block_size }; - self.view().middle_outer_views(i, count) + self.view().slice_outer_rbr(i..i + count) }) } @@ -1518,6 +1527,18 @@ where } }) } + + /// Return a mutable view into the current matrix + pub fn view_mut(&mut self) -> CsMatViewMutI { + CsMatViewMutI { + storage: self.storage, + nrows: self.nrows, + ncols: self.ncols, + indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()), + indices: &self.indices[..], + data: &mut self.data[..], + } + } } impl @@ -2544,11 +2565,13 @@ mod test { fn middle_outer_views() { let size = 11; let csr: CsMat = CsMat::eye(size); + #[allow(deprecated)] let v = csr.view().middle_outer_views(1, 3); assert_eq!(v.shape(), (3, size)); assert_eq!(v.nnz(), 3); let csc = csr.to_other_storage(); + #[allow(deprecated)] let v = csc.view().middle_outer_views(1, 3); assert_eq!(v.shape(), (size, 3)); assert_eq!(v.nnz(), 3); diff --git a/src/sparse/indptr.rs b/src/sparse/indptr.rs index 7c8340d3..f29093dd 100644 --- a/src/sparse/indptr.rs +++ b/src/sparse/indptr.rs @@ -82,6 +82,12 @@ where IndPtrBase { storage } } + pub fn view(&self) -> IndPtrView { + IndPtrView { + storage: &self.storage[..], + } + } + /// The length of the underlying storage pub fn len(&self) -> usize { self.storage.len() @@ -348,6 +354,15 @@ where // larger than the first, and that both can be represented as an usize self.storage.last().map(|i| *i - offset).unwrap_or(zero) } + + /// Slice this indptr to include only the outer dimensions in the range + /// `start..end`. + pub(crate) fn middle_slice( + &self, + range: impl crate::range::Range, + ) -> IndPtrView { + self.view().middle_slice_rbr(range) + } } impl IndPtr { @@ -387,11 +402,12 @@ impl<'a, Iptr: SpIndex> IndPtrView<'a, Iptr> { /// Slice this indptr to include only the outer dimensions in the range /// `start..end`. Reborrows to get the actual lifetime of the data wrapped /// in this view - pub(crate) fn middle_slice( + pub(crate) fn middle_slice_rbr( &self, - start: usize, - end: usize, + range: impl crate::range::Range, ) -> IndPtrView<'a, Iptr> { + let start = range.start(); + let end = range.end().unwrap_or_else(|| self.outer_dims()); IndPtrView { storage: &self.storage[start..=end], } diff --git a/src/sparse/slicing.rs b/src/sparse/slicing.rs new file mode 100644 index 00000000..7484134b --- /dev/null +++ b/src/sparse/slicing.rs @@ -0,0 +1,108 @@ +//! This module implementations to slice a matrix along the desired dimension. +//! We're using a sealed trait to enable using ranges for an idiomatic API. + +use crate::range::Range; +use crate::{CsMatBase, CsMatViewI, CsMatViewMutI, SpIndex}; +use std::ops::{Deref, DerefMut}; + +impl + CsMatBase +where + IptrStorage: Deref, + IStorage: Deref, + DStorage: Deref, +{ + /// Slice the outer dimension of the matrix according to the specified + /// range. + pub fn slice_outer(&self, range: S) -> CsMatViewI { + self.view().slice_outer_rbr(range) + } +} + +impl + CsMatBase +where + IptrStorage: Deref, + IStorage: Deref, + DStorage: DerefMut, +{ + /// Slice the outer dimension of the matrix according to the specified + /// range. + pub fn slice_outer_mut( + &mut self, + range: S, + ) -> CsMatViewMutI { + let start = range.start(); + let end = range.end().unwrap_or_else(|| self.outer_dims()); + if end < start { + panic!("Invalid view"); + } + let outer_inds_slice = self.indptr.outer_inds_slice(start, end); + let (nrows, ncols) = match self.storage() { + crate::CSR => ((end - start), self.ncols), + crate::CSC => (self.nrows, (end - start)), + }; + CsMatViewMutI { + nrows, + ncols, + storage: self.storage, + indptr: self.indptr.middle_slice(range), + indices: &self.indices[outer_inds_slice.clone()], + data: &mut self.data[outer_inds_slice], + } + } +} + +impl<'a, N, I, Iptr> crate::CsMatViewI<'a, N, I, Iptr> +where + I: crate::SpIndex, + Iptr: crate::SpIndex, +{ + /// Slice the outer dimension of the matrix according to the specified + /// range. + pub fn slice_outer_rbr( + &self, + range: S, + ) -> crate::CsMatViewI<'a, N, I, Iptr> + where + S: Range, + { + let start = range.start(); + let end = range.end().unwrap_or_else(|| self.outer_dims()); + if end < start { + panic!("Invalid view"); + } + let outer_inds_slice = self.indptr.outer_inds_slice(start, end); + let (nrows, ncols) = match self.storage() { + crate::CSR => ((end - start), self.ncols), + crate::CSC => (self.nrows, (end - start)), + }; + crate::CsMatViewI { + nrows, + ncols, + storage: self.storage, + indptr: self.indptr.middle_slice_rbr(range), + indices: &self.indices[outer_inds_slice.clone()], + data: &self.data[outer_inds_slice], + } + } +} + +#[cfg(test)] +mod tests { + use crate::CsMat; + + #[test] + fn slice_outer() { + let size = 11; + let csr: CsMat = CsMat::eye(size); + let sliced = csr.slice_outer(2..7); + let mut iter = sliced.into_iter(); + assert_eq!(iter.next().unwrap(), (&1., (0, 2))); + assert_eq!(iter.next().unwrap(), (&1., (1, 3))); + assert_eq!(iter.next().unwrap(), (&1., (2, 4))); + assert_eq!(iter.next().unwrap(), (&1., (3, 5))); + assert_eq!(iter.next().unwrap(), (&1., (4, 6))); + assert!(iter.next().is_none()); + } +} diff --git a/src/sparse/smmp.rs b/src/sparse/smmp.rs index 86bcc97d..4a9e47b3 100644 --- a/src/sparse/smmp.rs +++ b/src/sparse/smmp.rs @@ -304,7 +304,7 @@ where } else { l_rows }; - lhs_chunks.push(lhs.middle_outer_views(start, stop - start)); + lhs_chunks.push(lhs.slice_outer(start..stop)); res_indptr_chunks.push(vec![Iptr::zero(); stop - start + 1]); res_indices_chunks .push(Vec::with_capacity(lhs.nnz() + rhs.nnz() / chunk_size)); @@ -366,8 +366,7 @@ where for (row, nnz) in res_indptr.iter().enumerate() { let nnz = nnz.index(); if nnz - split_nnz > chunk_size && row > 0 { - lhs_chunks - .push(lhs.middle_outer_views(split_row, row - 1 - split_row)); + lhs_chunks.push(lhs.slice_outer(split_row..row - 1)); res_indptr_chunks.push(&res_indptr[split_row..row]); @@ -386,7 +385,7 @@ where } prev_nnz = nnz; } - lhs_chunks.push(lhs.middle_outer_views(split_row, lhs.rows() - split_row)); + lhs_chunks.push(lhs.slice_outer(split_row..lhs.rows())); res_indptr_chunks.push(&res_indptr[split_row..]); res_indices_chunks.push(res_indices_rem); res_data_chunks.push(res_data_rem); diff --git a/tests/slicing.rs b/tests/slicing.rs new file mode 100644 index 00000000..8a9bb8ab --- /dev/null +++ b/tests/slicing.rs @@ -0,0 +1,71 @@ +//! Test slicing from outside the crate to ensure the sealed trait +//! for ranges is effective + +#[test] +fn slice_outer() { + use sprs::CsMat; + let size = 11; + let csr: CsMat = CsMat::eye(size); + let sliced = csr.slice_outer(2..7); + let mut iter = sliced.into_iter(); + assert_eq!(iter.next().unwrap(), (&1., (0, 2))); + assert_eq!(iter.next().unwrap(), (&1., (1, 3))); + assert_eq!(iter.next().unwrap(), (&1., (2, 4))); + assert_eq!(iter.next().unwrap(), (&1., (3, 5))); + assert_eq!(iter.next().unwrap(), (&1., (4, 6))); + assert!(iter.next().is_none()); +} + +#[test] +fn slice_outer_mut() { + use sprs::CsMat; + let size = 11; + let mut csr: CsMat = CsMat::eye(size); + let mut sliced = csr.slice_outer_mut(2..7); + sliced.scale(2.); + let mut iter = sliced.into_iter(); + assert_eq!(iter.next().unwrap(), (&2., (0, 2))); + assert_eq!(iter.next().unwrap(), (&2., (1, 3))); + assert_eq!(iter.next().unwrap(), (&2., (2, 4))); + assert_eq!(iter.next().unwrap(), (&2., (3, 5))); + assert_eq!(iter.next().unwrap(), (&2., (4, 6))); + assert!(iter.next().is_none()); + + let mut iter = csr.into_iter(); + assert_eq!(iter.next().unwrap(), (&1., (0, 0))); + assert_eq!(iter.next().unwrap(), (&1., (1, 1))); + assert_eq!(iter.next().unwrap(), (&2., (2, 2))); + assert_eq!(iter.next().unwrap(), (&2., (3, 3))); + assert_eq!(iter.next().unwrap(), (&2., (4, 4))); + assert_eq!(iter.next().unwrap(), (&2., (5, 5))); + assert_eq!(iter.next().unwrap(), (&2., (6, 6))); + assert_eq!(iter.next().unwrap(), (&1., (7, 7))); + assert_eq!(iter.next().unwrap(), (&1., (8, 8))); + assert_eq!(iter.next().unwrap(), (&1., (9, 9))); + assert_eq!(iter.next().unwrap(), (&1., (10, 10))); + assert!(iter.next().is_none()); +} + +#[test] +fn slice_outer_other_ranges() { + use sprs::CsMat; + let size = 11; + let csr: CsMat = CsMat::eye(size); + let sliced = csr.slice_outer(..5); + let mut iter = sliced.into_iter(); + assert_eq!(iter.next().unwrap(), (&1., (0, 0))); + assert_eq!(iter.next().unwrap(), (&1., (1, 1))); + assert_eq!(iter.next().unwrap(), (&1., (2, 2))); + assert_eq!(iter.next().unwrap(), (&1., (3, 3))); + assert_eq!(iter.next().unwrap(), (&1., (4, 4))); + assert!(iter.next().is_none()); + + let sliced = csr.slice_outer(9..); + let mut iter = sliced.into_iter(); + assert_eq!(iter.next().unwrap(), (&1., (0, 9))); + assert_eq!(iter.next().unwrap(), (&1., (1, 10))); + assert!(iter.next().is_none()); + + let sliced = csr.slice_outer(..); + assert_eq!(sliced, csr.view()); +}