Skip to content

Commit f607ff6

Browse files
authored
Merge pull request #639 from nitsky/par-iter-axis-chunks
Parallel Iterator for AxisChunksIter
2 parents f489851 + f9ac9d4 commit f607ff6

File tree

6 files changed

+119
-0
lines changed

6 files changed

+119
-0
lines changed

parallel/src/lib.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,29 @@
6464
//! }
6565
//! ```
6666
//!
67+
//! ## Axis chunks iterators
68+
//!
69+
//! Use the parallel `.axis_chunks_iter()` to process your data in chunks.
70+
//!
71+
//! ```
72+
//! extern crate ndarray;
73+
//!
74+
//! use ndarray::Array;
75+
//! use ndarray::Axis;
76+
//! use ndarray_parallel::prelude::*;
77+
//!
78+
//! fn main() {
79+
//! let a = Array::linspace(0., 63., 64).into_shape((4, 16)).unwrap();
80+
//! let mut shapes = Vec::new();
81+
//! a.axis_chunks_iter(Axis(0), 3)
82+
//! .into_par_iter()
83+
//! .map(|chunk| chunk.shape().to_owned())
84+
//! .collect_into_vec(&mut shapes);
85+
//!
86+
//! assert_eq!(shapes, [vec![3, 16], vec![1, 16]]);
87+
//! }
88+
//! ```
89+
//!
6790
//! ## Zip
6891
//!
6992
//! Use zip for lock step function application across several arrays

parallel/src/par.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ use rayon::iter::plumbing::{Consumer, UnindexedConsumer};
88
use rayon::iter::IndexedParallelIterator;
99
use rayon::iter::ParallelIterator;
1010

11+
use ndarray::iter::AxisChunksIter;
12+
use ndarray::iter::AxisChunksIterMut;
1113
use ndarray::iter::AxisIter;
1214
use ndarray::iter::AxisIterMut;
1315
use ndarray::Dimension;
@@ -112,6 +114,8 @@ macro_rules! par_iter_wrapper {
112114

113115
par_iter_wrapper!(AxisIter, [Sync]);
114116
par_iter_wrapper!(AxisIterMut, [Send + Sync]);
117+
par_iter_wrapper!(AxisChunksIter, [Sync]);
118+
par_iter_wrapper!(AxisChunksIterMut, [Send + Sync]);
115119

116120
macro_rules! par_iter_view_wrapper {
117121
// thread_bounds are either Sync or Send + Sync

parallel/tests/rayon.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ use ndarray_parallel::prelude::*;
77

88
const M: usize = 1024 * 10;
99
const N: usize = 100;
10+
const CHUNK_SIZE: usize = 100;
11+
const N_CHUNKS: usize = (M + CHUNK_SIZE - 1) / CHUNK_SIZE;
1012

1113
#[test]
1214
fn test_axis_iter() {
@@ -53,3 +55,32 @@ fn test_regular_iter_collect() {
5355
let v = a.view().into_par_iter().map(|&x| x).collect::<Vec<_>>();
5456
assert_eq!(v.len(), a.len());
5557
}
58+
59+
#[test]
60+
fn test_axis_chunks_iter() {
61+
let mut a = Array2::<f64>::zeros((M, N));
62+
for (i, mut v) in a.axis_chunks_iter_mut(Axis(0), CHUNK_SIZE).enumerate() {
63+
v.fill(i as _);
64+
}
65+
assert_eq!(a.axis_chunks_iter(Axis(0), CHUNK_SIZE).len(), N_CHUNKS);
66+
let s: f64 = a
67+
.axis_chunks_iter(Axis(0), CHUNK_SIZE)
68+
.into_par_iter()
69+
.map(|x| x.sum())
70+
.sum();
71+
println!("{:?}", a.slice(s![..10, ..5]));
72+
assert_eq!(s, a.sum());
73+
}
74+
75+
#[test]
76+
fn test_axis_chunks_iter_mut() {
77+
let mut a = Array::linspace(0., 1.0f64, M * N)
78+
.into_shape((M, N))
79+
.unwrap();
80+
let b = a.mapv(|x| x.exp());
81+
a.axis_chunks_iter_mut(Axis(0), CHUNK_SIZE)
82+
.into_par_iter()
83+
.for_each(|mut v| v.mapv_inplace(|x| x.exp()));
84+
println!("{:?}", a.slice(s![..10, ..5]));
85+
assert!(a.all_close(&b, 0.001));
86+
}

src/parallel/mod.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
//! - [`ArrayView`](ArrayView): `.into_par_iter()`
1515
//! - [`ArrayViewMut`](ArrayViewMut): `.into_par_iter()`
1616
//! - [`AxisIter`](iter::AxisIter), [`AxisIterMut`](iter::AxisIterMut): `.into_par_iter()`
17+
//! - [`AxisChunksIter`](iter::AxisChunksIter), [`AxisChunksIterMut`](iter::AxisChunksIterMut): `.into_par_iter()`
1718
//! - [`Zip`] `.into_par_iter()`
1819
//!
1920
//! The following other parallelized methods exist:
@@ -76,6 +77,29 @@
7677
//! }
7778
//! ```
7879
//!
80+
//! ## Axis chunks iterators
81+
//!
82+
//! Use the parallel `.axis_chunks_iter()` to process your data in chunks.
83+
//!
84+
//! ```
85+
//! extern crate ndarray;
86+
//!
87+
//! use ndarray::Array;
88+
//! use ndarray::Axis;
89+
//! use ndarray::parallel::prelude::*;
90+
//!
91+
//! fn main() {
92+
//! let a = Array::linspace(0., 63., 64).into_shape((4, 16)).unwrap();
93+
//! let mut shapes = Vec::new();
94+
//! a.axis_chunks_iter(Axis(0), 3)
95+
//! .into_par_iter()
96+
//! .map(|chunk| chunk.shape().to_owned())
97+
//! .collect_into_vec(&mut shapes);
98+
//!
99+
//! assert_eq!(shapes, [vec![3, 16], vec![1, 16]]);
100+
//! }
101+
//! ```
102+
//!
79103
//! ## Zip
80104
//!
81105
//! Use zip for lock step function application across several arrays

src/parallel/par.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ use rayon::iter::IndexedParallelIterator;
99
use rayon::iter::ParallelIterator;
1010
use rayon::prelude::IntoParallelIterator;
1111

12+
use crate::iter::AxisChunksIter;
13+
use crate::iter::AxisChunksIterMut;
1214
use crate::iter::AxisIter;
1315
use crate::iter::AxisIterMut;
1416
use crate::Dimension;
@@ -112,6 +114,8 @@ macro_rules! par_iter_wrapper {
112114

113115
par_iter_wrapper!(AxisIter, [Sync]);
114116
par_iter_wrapper!(AxisIterMut, [Send + Sync]);
117+
par_iter_wrapper!(AxisChunksIter, [Sync]);
118+
par_iter_wrapper!(AxisChunksIterMut, [Send + Sync]);
115119

116120
macro_rules! par_iter_view_wrapper {
117121
// thread_bounds are either Sync or Send + Sync

tests/par_rayon.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ use ndarray::prelude::*;
55

66
const M: usize = 1024 * 10;
77
const N: usize = 100;
8+
const CHUNK_SIZE: usize = 100;
9+
const N_CHUNKS: usize = (M + CHUNK_SIZE - 1) / CHUNK_SIZE;
810

911
#[test]
1012
fn test_axis_iter() {
@@ -53,3 +55,34 @@ fn test_regular_iter_collect() {
5355
let v = a.view().into_par_iter().map(|&x| x).collect::<Vec<_>>();
5456
assert_eq!(v.len(), a.len());
5557
}
58+
59+
#[test]
60+
fn test_axis_chunks_iter() {
61+
let mut a = Array2::<f64>::zeros((M, N));
62+
for (i, mut v) in a.axis_chunks_iter_mut(Axis(0), CHUNK_SIZE).enumerate() {
63+
v.fill(i as _);
64+
}
65+
assert_eq!(a.axis_chunks_iter(Axis(0), CHUNK_SIZE).len(), N_CHUNKS);
66+
let s: f64 = a
67+
.axis_chunks_iter(Axis(0), CHUNK_SIZE)
68+
.into_par_iter()
69+
.map(|x| x.sum())
70+
.sum();
71+
println!("{:?}", a.slice(s![..10, ..5]));
72+
assert_eq!(s, a.sum());
73+
}
74+
75+
#[test]
76+
#[cfg(feature = "approx")]
77+
fn test_axis_chunks_iter_mut() {
78+
use approx::assert_abs_diff_eq;
79+
let mut a = Array::linspace(0., 1.0f64, M * N)
80+
.into_shape((M, N))
81+
.unwrap();
82+
let b = a.mapv(|x| x.exp());
83+
a.axis_chunks_iter_mut(Axis(0), CHUNK_SIZE)
84+
.into_par_iter()
85+
.for_each(|mut v| v.mapv_inplace(|x| x.exp()));
86+
println!("{:?}", a.slice(s![..10, ..5]));
87+
assert_abs_diff_eq!(a, b, epsilon = 0.001);
88+
}

0 commit comments

Comments
 (0)