File tree 2 files changed +51
-0
lines changed
2 files changed +51
-0
lines changed Original file line number Diff line number Diff line change @@ -982,6 +982,29 @@ macro_rules! map_impl {
982
982
dimension: self . dimension,
983
983
}
984
984
}
985
+
986
+ /// Apply and collect the results into a new array, which has the same size as the
987
+ /// inputs.
988
+ ///
989
+ /// If all inputs are c- or f-order respectively, that is preserved in the output.
990
+ ///
991
+ /// Restricted to functions that produce copyable results for technical reasons; other
992
+ /// cases are not yet implemented.
993
+ pub fn apply_collect<R >( self , mut f: impl FnMut ( $( $p:: Item , ) * ) -> R ) -> Array <R , D >
994
+ where R : Copy ,
995
+ {
996
+ unsafe {
997
+ let is_c = self . layout. is( CORDER ) ;
998
+ let is_f = !is_c && self . layout. is( FORDER ) ;
999
+ let mut output = Array :: uninitialized( self . dimension. clone( ) . set_f( is_f) ) ;
1000
+ self . and( output. raw_view_mut( ) )
1001
+ . apply( move |$( $p, ) * output_| {
1002
+ std:: ptr:: write( output_, f( $( $p ) ,* ) ) ;
1003
+ } ) ;
1004
+ output
1005
+ }
1006
+ }
1007
+
985
1008
) ;
986
1009
987
1010
/// Split the `Zip` evenly in two.
Original file line number Diff line number Diff line change @@ -49,6 +49,34 @@ fn test_azip2_3() {
49
49
assert ! ( a != b) ;
50
50
}
51
51
52
+ #[ test]
53
+ #[ cfg( feature = "approx" ) ]
54
+ fn test_zip_collect ( ) {
55
+ use approx:: assert_abs_diff_eq;
56
+
57
+ // test Zip::apply_collect and that it preserves c/f layout.
58
+
59
+ let b = Array :: from_shape_fn ( ( 5 , 10 ) , |( i, j) | 1. / ( i + 2 * j + 1 ) as f32 ) ;
60
+ let c = Array :: from_shape_fn ( ( 5 , 10 ) , |( i, j) | f32:: exp ( ( i + j) as f32 ) ) ;
61
+
62
+ {
63
+ let a = Zip :: from ( & b) . and ( & c) . apply_collect ( |x, y| x + y) ;
64
+
65
+ assert_abs_diff_eq ! ( a, & b + & c, epsilon = 1e-6 ) ;
66
+ assert_eq ! ( a. strides( ) , b. strides( ) ) ;
67
+ }
68
+
69
+ {
70
+ let b = b. t ( ) ;
71
+ let c = c. t ( ) ;
72
+
73
+ let a = Zip :: from ( & b) . and ( & c) . apply_collect ( |x, y| x + y) ;
74
+
75
+ assert_abs_diff_eq ! ( a, & b + & c, epsilon = 1e-6 ) ;
76
+ assert_eq ! ( a. strides( ) , b. strides( ) ) ;
77
+ }
78
+ }
79
+
52
80
#[ test]
53
81
fn test_azip_syntax_trailing_comma ( ) {
54
82
let mut b = Array :: < i32 , _ > :: zeros ( ( 5 , 5 ) ) ;
You can’t perform that action at this time.
0 commit comments