Skip to content

More ergonomic slicing #252

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions Guidelines.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,16 @@ using ``rustfmt`` on the latest stable Rust channel:
cargo fmt --all

.. _rustfmt: https://github.com/rust-lang-nursery/rustfmt

Reborrowing
===========

Some methods return views of their containers, eg ``CsMatBase::slice_outer``
returns a ``CsMatViewI``. However, in certain situations, mostly when
implementing iterators, we are calling these kind of methods on a view, and
need to take the lifetime of the view, not the lifetime of ``self`` in the
method call. To deal with this issue, the method should in fact be implemented
on the view type (on ``CsMatViewI`` in the example), with a ``_rbr`` suffix (
``CsMatViewI::slice_outer_rbr`` in the example), and the implementation on the
base type should simply call the view version (in the example, it should call
``self.view().slice_outer_rbr(range)``).
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ pub mod errors;
pub mod indexing;
pub mod io;
pub mod num_kinds;
mod range;
mod sparse;
pub mod stack;

Expand Down
70 changes: 70 additions & 0 deletions src/range.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
//! This module abstracts over ranges to allow functions in the crate to take
//! ranges as input in an ergonomic way. It howvers seals this abstractions to
//! leave full control in this crate.

/// Abstract over `std::ops::{Range,RangeFrom,RangeTo,RangeFull}`
pub trait Range {
fn start(&self) -> usize;

fn end(&self) -> Option<usize>;
}

impl Range for std::ops::Range<usize> {
fn start(&self) -> usize {
self.start
}

fn end(&self) -> Option<usize> {
Some(self.end)
}
}

impl Range for std::ops::RangeFrom<usize> {
fn start(&self) -> usize {
self.start
}

fn end(&self) -> Option<usize> {
None
}
}

impl Range for std::ops::RangeTo<usize> {
fn start(&self) -> usize {
0
}

fn end(&self) -> Option<usize> {
Some(self.end)
}
}

impl Range for std::ops::RangeFull {
fn start(&self) -> usize {
0
}

fn end(&self) -> Option<usize> {
None
}
}

impl Range for std::ops::RangeInclusive<usize> {
fn start(&self) -> usize {
*self.start()
}

fn end(&self) -> Option<usize> {
Some(*self.end() + 1)
}
}

impl Range for std::ops::RangeToInclusive<usize> {
fn start(&self) -> usize {
0
}

fn end(&self) -> Option<usize> {
Some(self.end + 1)
}
}
1 change: 1 addition & 0 deletions src/sparse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ pub mod kronecker;
pub mod linalg;
pub mod permutation;
pub mod prod;
pub mod slicing;
pub mod smmp;
pub mod special_mats;
pub mod symmetric;
Expand Down
27 changes: 25 additions & 2 deletions src/sparse/csmat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,15 @@ impl<'a, N: 'a, I: 'a + SpIndex, Iptr: 'a + SpIndex>
/// Get a view into count contiguous outer dimensions, starting from i.
///
/// eg this gets the rows from i to i + count in a CSR matrix
///
/// This function is now deprecated, as using an index and a count is not
/// ergonomic. The replacement, `slice_outer`, leverages the
/// `std::ops::Range` family of types, which is better integrated into the
/// ecosystem.
#[deprecated(
since = "0.10.0",
note = "Please use the `slice_outer` method instead"
)]
pub fn middle_outer_views(
&self,
i: usize,
Expand All @@ -682,7 +691,7 @@ impl<'a, N: 'a, I: 'a + SpIndex, Iptr: 'a + SpIndex>
storage: self.storage,
nrows,
ncols,
indptr: self.indptr.middle_slice(i, iend),
indptr: self.indptr.middle_slice_rbr(i..iend),
indices: &self.indices[data_range.clone()],
data: &self.data[data_range],
}
Expand Down Expand Up @@ -1185,7 +1194,7 @@ where
} else {
block_size
};
self.view().middle_outer_views(i, count)
self.view().slice_outer_rbr(i..i + count)
})
}

Expand Down Expand Up @@ -1518,6 +1527,18 @@ where
}
})
}

/// Return a mutable view into the current matrix
pub fn view_mut(&mut self) -> CsMatViewMutI<N, I, Iptr> {
CsMatViewMutI {
storage: self.storage,
nrows: self.nrows,
ncols: self.ncols,
indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
indices: &self.indices[..],
data: &mut self.data[..],
}
}
}

impl<N, I, Iptr, IptrStorage, IndStorage, DataStorage>
Expand Down Expand Up @@ -2544,11 +2565,13 @@ mod test {
fn middle_outer_views() {
let size = 11;
let csr: CsMat<f64> = CsMat::eye(size);
#[allow(deprecated)]
let v = csr.view().middle_outer_views(1, 3);
assert_eq!(v.shape(), (3, size));
assert_eq!(v.nnz(), 3);

let csc = csr.to_other_storage();
#[allow(deprecated)]
let v = csc.view().middle_outer_views(1, 3);
assert_eq!(v.shape(), (size, 3));
assert_eq!(v.nnz(), 3);
Expand Down
22 changes: 19 additions & 3 deletions src/sparse/indptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ where
IndPtrBase { storage }
}

pub fn view(&self) -> IndPtrView<Iptr> {
IndPtrView {
storage: &self.storage[..],
}
}

/// The length of the underlying storage
pub fn len(&self) -> usize {
self.storage.len()
Expand Down Expand Up @@ -348,6 +354,15 @@ where
// larger than the first, and that both can be represented as an usize
self.storage.last().map(|i| *i - offset).unwrap_or(zero)
}

/// Slice this indptr to include only the outer dimensions in the range
/// `start..end`.
pub(crate) fn middle_slice(
&self,
range: impl crate::range::Range,
) -> IndPtrView<Iptr> {
self.view().middle_slice_rbr(range)
}
}

impl<Iptr: SpIndex> IndPtr<Iptr> {
Expand Down Expand Up @@ -387,11 +402,12 @@ impl<'a, Iptr: SpIndex> IndPtrView<'a, Iptr> {
/// Slice this indptr to include only the outer dimensions in the range
/// `start..end`. Reborrows to get the actual lifetime of the data wrapped
/// in this view
pub(crate) fn middle_slice(
pub(crate) fn middle_slice_rbr(
&self,
start: usize,
end: usize,
range: impl crate::range::Range,
) -> IndPtrView<'a, Iptr> {
let start = range.start();
let end = range.end().unwrap_or_else(|| self.outer_dims());
IndPtrView {
storage: &self.storage[start..=end],
}
Expand Down
108 changes: 108 additions & 0 deletions src/sparse/slicing.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
//! This module implementations to slice a matrix along the desired dimension.
//! We're using a sealed trait to enable using ranges for an idiomatic API.

use crate::range::Range;
use crate::{CsMatBase, CsMatViewI, CsMatViewMutI, SpIndex};
use std::ops::{Deref, DerefMut};

impl<N, I: SpIndex, Iptr: SpIndex, IptrStorage, IStorage, DStorage>
CsMatBase<N, I, IptrStorage, IStorage, DStorage, Iptr>
where
IptrStorage: Deref<Target = [Iptr]>,
IStorage: Deref<Target = [I]>,
DStorage: Deref<Target = [N]>,
{
/// Slice the outer dimension of the matrix according to the specified
/// range.
pub fn slice_outer<S: Range>(&self, range: S) -> CsMatViewI<N, I, Iptr> {
self.view().slice_outer_rbr(range)
}
}

impl<N, I: SpIndex, Iptr: SpIndex, IptrStorage, IStorage, DStorage>
CsMatBase<N, I, IptrStorage, IStorage, DStorage, Iptr>
where
IptrStorage: Deref<Target = [Iptr]>,
IStorage: Deref<Target = [I]>,
DStorage: DerefMut<Target = [N]>,
{
/// Slice the outer dimension of the matrix according to the specified
/// range.
pub fn slice_outer_mut<S: Range>(
&mut self,
range: S,
) -> CsMatViewMutI<N, I, Iptr> {
let start = range.start();
let end = range.end().unwrap_or_else(|| self.outer_dims());
if end < start {
panic!("Invalid view");
}
let outer_inds_slice = self.indptr.outer_inds_slice(start, end);
let (nrows, ncols) = match self.storage() {
crate::CSR => ((end - start), self.ncols),
crate::CSC => (self.nrows, (end - start)),
};
CsMatViewMutI {
nrows,
ncols,
storage: self.storage,
indptr: self.indptr.middle_slice(range),
indices: &self.indices[outer_inds_slice.clone()],
data: &mut self.data[outer_inds_slice],
}
}
}

impl<'a, N, I, Iptr> crate::CsMatViewI<'a, N, I, Iptr>
where
I: crate::SpIndex,
Iptr: crate::SpIndex,
{
/// Slice the outer dimension of the matrix according to the specified
/// range.
pub fn slice_outer_rbr<S>(
&self,
range: S,
) -> crate::CsMatViewI<'a, N, I, Iptr>
where
S: Range,
{
let start = range.start();
let end = range.end().unwrap_or_else(|| self.outer_dims());
if end < start {
panic!("Invalid view");
}
let outer_inds_slice = self.indptr.outer_inds_slice(start, end);
let (nrows, ncols) = match self.storage() {
crate::CSR => ((end - start), self.ncols),
crate::CSC => (self.nrows, (end - start)),
};
crate::CsMatViewI {
nrows,
ncols,
storage: self.storage,
indptr: self.indptr.middle_slice_rbr(range),
indices: &self.indices[outer_inds_slice.clone()],
data: &self.data[outer_inds_slice],
}
}
}

#[cfg(test)]
mod tests {
use crate::CsMat;

#[test]
fn slice_outer() {
let size = 11;
let csr: CsMat<f64> = CsMat::eye(size);
let sliced = csr.slice_outer(2..7);
let mut iter = sliced.into_iter();
assert_eq!(iter.next().unwrap(), (&1., (0, 2)));
assert_eq!(iter.next().unwrap(), (&1., (1, 3)));
assert_eq!(iter.next().unwrap(), (&1., (2, 4)));
assert_eq!(iter.next().unwrap(), (&1., (3, 5)));
assert_eq!(iter.next().unwrap(), (&1., (4, 6)));
assert!(iter.next().is_none());
}
}
7 changes: 3 additions & 4 deletions src/sparse/smmp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ where
} else {
l_rows
};
lhs_chunks.push(lhs.middle_outer_views(start, stop - start));
lhs_chunks.push(lhs.slice_outer(start..stop));
res_indptr_chunks.push(vec![Iptr::zero(); stop - start + 1]);
res_indices_chunks
.push(Vec::with_capacity(lhs.nnz() + rhs.nnz() / chunk_size));
Expand Down Expand Up @@ -366,8 +366,7 @@ where
for (row, nnz) in res_indptr.iter().enumerate() {
let nnz = nnz.index();
if nnz - split_nnz > chunk_size && row > 0 {
lhs_chunks
.push(lhs.middle_outer_views(split_row, row - 1 - split_row));
lhs_chunks.push(lhs.slice_outer(split_row..row - 1));

res_indptr_chunks.push(&res_indptr[split_row..row]);

Expand All @@ -386,7 +385,7 @@ where
}
prev_nnz = nnz;
}
lhs_chunks.push(lhs.middle_outer_views(split_row, lhs.rows() - split_row));
lhs_chunks.push(lhs.slice_outer(split_row..lhs.rows()));
res_indptr_chunks.push(&res_indptr[split_row..]);
res_indices_chunks.push(res_indices_rem);
res_data_chunks.push(res_data_rem);
Expand Down
Loading