1
+ extern crate approx;
1
2
extern crate defmac;
2
3
extern crate ndarray;
3
4
extern crate num_traits;
@@ -8,34 +9,8 @@ use ndarray::linalg::general_mat_mul;
8
9
use ndarray:: linalg:: general_mat_vec_mul;
9
10
use ndarray:: { Ix , Ixs , SliceInfo , SliceOrIndex } ;
10
11
11
- use std :: fmt ;
12
+ use approx :: { assert_abs_diff_eq , assert_relative_eq } ;
12
13
use defmac:: defmac;
13
- use num_traits:: Float ;
14
-
15
- fn assert_approx_eq < F : fmt:: Debug + Float > ( f : F , g : F , tol : F ) -> bool {
16
- assert ! ( ( f - g) . abs( ) <= tol, "{:?} approx== {:?} (tol={:?})" ,
17
- f, g, tol) ;
18
- true
19
- }
20
-
21
- fn assert_close < D > ( a : ArrayView < f64 , D > , b : ArrayView < f64 , D > )
22
- where D : Dimension ,
23
- {
24
- let diff = ( & a - & b) . mapv_into ( f64:: abs) ;
25
-
26
- let rtol = 1e-7 ;
27
- let atol = 1e-12 ;
28
- let crtol = b. mapv ( |x| x. abs ( ) * rtol) ;
29
- let tol = crtol + atol;
30
- let tol_m_diff = & diff - & tol;
31
- let maxdiff = tol_m_diff. fold ( 0. /0. , |x, y| f64:: max ( x, * y) ) ;
32
- println ! ( "diff offset from tolerance level= {:.2e}" , maxdiff) ;
33
- if maxdiff > 0. {
34
- println ! ( "{:.4?}" , a) ;
35
- println ! ( "{:.4?}" , b) ;
36
- panic ! ( "results differ" ) ;
37
- }
38
- }
39
14
40
15
fn reference_dot < ' a , A , V1 , V2 > ( a : V1 , b : V2 ) -> A
41
16
where A : NdFloat ,
@@ -54,32 +29,32 @@ fn dot_product() {
54
29
let a = Array :: range ( 0. , 69. , 1. ) ;
55
30
let b = & a * 2. - 7. ;
56
31
let dot = 197846. ;
57
- assert_approx_eq ( a. dot ( & b) , reference_dot ( & a, & b) , 1e-5 ) ;
32
+ assert_abs_diff_eq ! ( a. dot( & b) , reference_dot( & a, & b) , epsilon = 1e-5 ) ;
58
33
59
34
// test different alignments
60
35
let max = 8 as Ixs ;
61
36
for i in 1 ..max {
62
37
let a1 = a. slice ( s ! [ i..] ) ;
63
38
let b1 = b. slice ( s ! [ i..] ) ;
64
- assert_approx_eq ( a1. dot ( & b1) , reference_dot ( & a1, & b1) , 1e-5 ) ;
39
+ assert_abs_diff_eq ! ( a1. dot( & b1) , reference_dot( & a1, & b1) , epsilon = 1e-5 ) ;
65
40
let a2 = a. slice ( s ! [ ..-i] ) ;
66
41
let b2 = b. slice ( s ! [ i..] ) ;
67
- assert_approx_eq ( a2. dot ( & b2) , reference_dot ( & a2, & b2) , 1e-5 ) ;
42
+ assert_abs_diff_eq ! ( a2. dot( & b2) , reference_dot( & a2, & b2) , epsilon = 1e-5 ) ;
68
43
}
69
44
70
45
71
46
let a = a. map ( |f| * f as f32 ) ;
72
47
let b = b. map ( |f| * f as f32 ) ;
73
- assert_approx_eq ( a. dot ( & b) , dot as f32 , 1e-5 ) ;
48
+ assert_abs_diff_eq ! ( a. dot( & b) , dot as f32 , epsilon = 1e-5 ) ;
74
49
75
50
let max = 8 as Ixs ;
76
51
for i in 1 ..max {
77
52
let a1 = a. slice ( s ! [ i..] ) ;
78
53
let b1 = b. slice ( s ! [ i..] ) ;
79
- assert_approx_eq ( a1. dot ( & b1) , reference_dot ( & a1, & b1) , 1e-5 ) ;
54
+ assert_abs_diff_eq ! ( a1. dot( & b1) , reference_dot( & a1, & b1) , epsilon = 1e-5 ) ;
80
55
let a2 = a. slice ( s ! [ ..-i] ) ;
81
56
let b2 = b. slice ( s ! [ i..] ) ;
82
- assert_approx_eq ( a2. dot ( & b2) , reference_dot ( & a2, & b2) , 1e-5 ) ;
57
+ assert_abs_diff_eq ! ( a2. dot( & b2) , reference_dot( & a2, & b2) , epsilon = 1e-5 ) ;
83
58
}
84
59
85
60
let a = a. map ( |f| * f as i32 ) ;
@@ -94,17 +69,17 @@ fn dot_product_0() {
94
69
let x = 1.5 ;
95
70
let b = aview0 ( & x) ;
96
71
let b = b. broadcast ( a. dim ( ) ) . unwrap ( ) ;
97
- assert_approx_eq ( a. dot ( & b) , reference_dot ( & a, & b) , 1e-5 ) ;
72
+ assert_abs_diff_eq ! ( a. dot( & b) , reference_dot( & a, & b) , epsilon = 1e-5 ) ;
98
73
99
74
// test different alignments
100
75
let max = 8 as Ixs ;
101
76
for i in 1 ..max {
102
77
let a1 = a. slice ( s ! [ i..] ) ;
103
78
let b1 = b. slice ( s ! [ i..] ) ;
104
- assert_approx_eq ( a1. dot ( & b1) , reference_dot ( & a1, & b1) , 1e-5 ) ;
79
+ assert_abs_diff_eq ! ( a1. dot( & b1) , reference_dot( & a1, & b1) , epsilon = 1e-5 ) ;
105
80
let a2 = a. slice ( s ! [ ..-i] ) ;
106
81
let b2 = b. slice ( s ! [ i..] ) ;
107
- assert_approx_eq ( a2. dot ( & b2) , reference_dot ( & a2, & b2) , 1e-5 ) ;
82
+ assert_abs_diff_eq ! ( a2. dot( & b2) , reference_dot( & a2, & b2) , epsilon = 1e-5 ) ;
108
83
}
109
84
}
110
85
@@ -117,13 +92,13 @@ fn dot_product_neg_stride() {
117
92
// both negative
118
93
let a = a. slice ( s ! [ ..; stride] ) ;
119
94
let b = b. slice ( s ! [ ..; stride] ) ;
120
- assert_approx_eq ( a. dot ( & b) , reference_dot ( & a, & b) , 1e-5 ) ;
95
+ assert_abs_diff_eq ! ( a. dot( & b) , reference_dot( & a, & b) , epsilon = 1e-5 ) ;
121
96
}
122
97
for stride in -10 ..0 {
123
98
// mixed
124
99
let a = a. slice ( s ! [ ..; -stride] ) ;
125
100
let b = b. slice ( s ! [ ..; stride] ) ;
126
- assert_approx_eq ( a. dot ( & b) , reference_dot ( & a, & b) , 1e-5 ) ;
101
+ assert_abs_diff_eq ! ( a. dot( & b) , reference_dot( & a, & b) , epsilon = 1e-5 ) ;
127
102
}
128
103
}
129
104
@@ -402,7 +377,7 @@ fn scaled_add_2() {
402
377
answerv += & ( beta * & c) ;
403
378
av. scaled_add ( beta, & c) ;
404
379
}
405
- assert_close ( a . view ( ) , answer. view ( ) ) ;
380
+ assert_relative_eq ! ( a , answer, epsilon = 1e-12 , max_relative = 1e-7 ) ;
406
381
}
407
382
}
408
383
}
@@ -451,7 +426,7 @@ fn scaled_add_3() {
451
426
answerv += & ( beta * & c) ;
452
427
av. scaled_add ( beta, & c) ;
453
428
}
454
- assert_close ( a . view ( ) , answer. view ( ) ) ;
429
+ assert_relative_eq ! ( a , answer, epsilon = 1e-12 , max_relative = 1e-7 ) ;
455
430
}
456
431
}
457
432
}
@@ -490,7 +465,7 @@ fn gen_mat_mul() {
490
465
491
466
general_mat_mul ( alpha, & a, & b, beta, & mut cv) ;
492
467
}
493
- assert_close ( c . view ( ) , answer. view ( ) ) ;
468
+ assert_relative_eq ! ( c , answer, epsilon = 1e-12 , max_relative = 1e-7 ) ;
494
469
}
495
470
}
496
471
}
@@ -507,7 +482,7 @@ fn gemm_64_1_f() {
507
482
let mut y = range_mat64 ( m, 1 ) ;
508
483
let answer = reference_mat_mul ( & a, & x) + & y;
509
484
general_mat_mul ( 1.0 , & a, & x, 1.0 , & mut y) ;
510
- assert_close ( y . view ( ) , answer. view ( ) ) ;
485
+ assert_relative_eq ! ( y , answer, epsilon = 1e-12 , max_relative = 1e-7 ) ;
511
486
}
512
487
513
488
#[ test]
@@ -572,7 +547,7 @@ fn gen_mat_vec_mul() {
572
547
573
548
general_mat_vec_mul ( alpha, & a, & b, beta, & mut cv) ;
574
549
}
575
- assert_close ( c . view ( ) , answer. view ( ) ) ;
550
+ assert_relative_eq ! ( c , answer, epsilon = 1e-12 , max_relative = 1e-7 ) ;
576
551
}
577
552
}
578
553
}
@@ -614,7 +589,7 @@ fn vec_mat_mul() {
614
589
615
590
c. slice_mut ( s ! [ ..; s2] ) . assign ( & a. dot ( & b) ) ;
616
591
}
617
- assert_close ( c . view ( ) , answer. view ( ) ) ;
592
+ assert_relative_eq ! ( c , answer, epsilon = 1e-12 , max_relative = 1e-7 ) ;
618
593
}
619
594
}
620
595
}
0 commit comments