Skip to content

Commit 2ab90e9

Browse files
committed
FEAT: Add method Zip::apply_collect to collect results into Array
Use uninitialized to not need to zero or clone result elements; Restrict to Copy to use this constructor (and avoid panic safety and drop on panic issues for now).
1 parent 25cf334 commit 2ab90e9

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

src/zip/mod.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -982,6 +982,29 @@ macro_rules! map_impl {
982982
dimension: self.dimension,
983983
}
984984
}
985+
986+
/// Apply and collect the results into a new array, which has the same size as the
987+
/// inputs.
988+
///
989+
/// If all inputs are c- or f-order respectively, that is preserved in the output.
990+
///
991+
/// Restricted to functions that produce copyable results for technical reasons; other
992+
/// cases are not yet implemented.
993+
pub fn apply_collect<R>(self, mut f: impl FnMut($($p::Item,)* ) -> R) -> Array<R, D>
994+
where R: Copy,
995+
{
996+
unsafe {
997+
let is_c = self.layout.is(CORDER);
998+
let is_f = !is_c && self.layout.is(FORDER);
999+
let mut output = Array::uninitialized(self.dimension.clone().set_f(is_f));
1000+
self.and(output.raw_view_mut())
1001+
.apply(move |$($p, )* output_| {
1002+
std::ptr::write(output_, f($($p ),*));
1003+
});
1004+
output
1005+
}
1006+
}
1007+
9851008
);
9861009

9871010
/// Split the `Zip` evenly in two.

tests/azip.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,34 @@ fn test_azip2_3() {
4949
assert!(a != b);
5050
}
5151

52+
#[test]
53+
#[cfg(feature = "approx")]
54+
fn test_zip_collect() {
55+
use approx::assert_abs_diff_eq;
56+
57+
// test Zip::apply_collect and that it preserves c/f layout.
58+
59+
let b = Array::from_shape_fn((5, 10), |(i, j)| 1. / (i + 2 * j + 1) as f32);
60+
let c = Array::from_shape_fn((5, 10), |(i, j)| f32::exp((i + j) as f32));
61+
62+
{
63+
let a = Zip::from(&b).and(&c).apply_collect(|x, y| x + y);
64+
65+
assert_abs_diff_eq!(a, &b + &c, epsilon = 1e-6);
66+
assert_eq!(a.strides(), b.strides());
67+
}
68+
69+
{
70+
let b = b.t();
71+
let c = c.t();
72+
73+
let a = Zip::from(&b).and(&c).apply_collect(|x, y| x + y);
74+
75+
assert_abs_diff_eq!(a, &b + &c, epsilon = 1e-6);
76+
assert_eq!(a.strides(), b.strides());
77+
}
78+
}
79+
5280
#[test]
5381
fn test_azip_syntax_trailing_comma() {
5482
let mut b = Array::<i32, _>::zeros((5, 5));

0 commit comments

Comments
 (0)