Skip to content

Commit 6b56adf

Browse files
authored
Use approx for approx. comparisons in tests (#631)
This allows us to get rid of the custom comparison functions.
1 parent 97def32 commit 6b56adf

File tree

5 files changed

+109
-238
lines changed

5 files changed

+109
-238
lines changed

blas-tests/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ publish = false
88
test = false
99

1010
[dev-dependencies]
11-
ndarray = { path = "../", features = ["blas"] }
11+
approx = "0.3.2"
12+
ndarray = { path = "../", features = ["approx", "blas"] }
1213
blas-src = { version = "0.2.0", default-features = false, features = ["openblas"] }
1314
openblas-src = { version = "0.6.0", default-features = false, features = ["cblas", "system"] }
1415
defmac = "0.2"
1516
num-traits = "0.2"
16-

blas-tests/tests/oper.rs

Lines changed: 19 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
extern crate approx;
12
extern crate defmac;
23
extern crate ndarray;
34
extern crate num_traits;
@@ -8,34 +9,8 @@ use ndarray::linalg::general_mat_mul;
89
use ndarray::linalg::general_mat_vec_mul;
910
use ndarray::{Ix, Ixs, SliceInfo, SliceOrIndex};
1011

11-
use std::fmt;
12+
use approx::{assert_abs_diff_eq, assert_relative_eq};
1213
use defmac::defmac;
13-
use num_traits::Float;
14-
15-
fn assert_approx_eq<F: fmt::Debug + Float>(f: F, g: F, tol: F) -> bool {
16-
assert!((f - g).abs() <= tol, "{:?} approx== {:?} (tol={:?})",
17-
f, g, tol);
18-
true
19-
}
20-
21-
fn assert_close<D>(a: ArrayView<f64, D>, b: ArrayView<f64, D>)
22-
where D: Dimension,
23-
{
24-
let diff = (&a - &b).mapv_into(f64::abs);
25-
26-
let rtol = 1e-7;
27-
let atol = 1e-12;
28-
let crtol = b.mapv(|x| x.abs() * rtol);
29-
let tol = crtol + atol;
30-
let tol_m_diff = &diff - &tol;
31-
let maxdiff = tol_m_diff.fold(0./0., |x, y| f64::max(x, *y));
32-
println!("diff offset from tolerance level= {:.2e}", maxdiff);
33-
if maxdiff > 0. {
34-
println!("{:.4?}", a);
35-
println!("{:.4?}", b);
36-
panic!("results differ");
37-
}
38-
}
3914

4015
fn reference_dot<'a,A, V1, V2>(a: V1, b: V2) -> A
4116
where A: NdFloat,
@@ -54,32 +29,32 @@ fn dot_product() {
5429
let a = Array::range(0., 69., 1.);
5530
let b = &a * 2. - 7.;
5631
let dot = 197846.;
57-
assert_approx_eq(a.dot(&b), reference_dot(&a, &b), 1e-5);
32+
assert_abs_diff_eq!(a.dot(&b), reference_dot(&a, &b), epsilon = 1e-5);
5833

5934
// test different alignments
6035
let max = 8 as Ixs;
6136
for i in 1..max {
6237
let a1 = a.slice(s![i..]);
6338
let b1 = b.slice(s![i..]);
64-
assert_approx_eq(a1.dot(&b1), reference_dot(&a1, &b1), 1e-5);
39+
assert_abs_diff_eq!(a1.dot(&b1), reference_dot(&a1, &b1), epsilon = 1e-5);
6540
let a2 = a.slice(s![..-i]);
6641
let b2 = b.slice(s![i..]);
67-
assert_approx_eq(a2.dot(&b2), reference_dot(&a2, &b2), 1e-5);
42+
assert_abs_diff_eq!(a2.dot(&b2), reference_dot(&a2, &b2), epsilon = 1e-5);
6843
}
6944

7045

7146
let a = a.map(|f| *f as f32);
7247
let b = b.map(|f| *f as f32);
73-
assert_approx_eq(a.dot(&b), dot as f32, 1e-5);
48+
assert_abs_diff_eq!(a.dot(&b), dot as f32, epsilon = 1e-5);
7449

7550
let max = 8 as Ixs;
7651
for i in 1..max {
7752
let a1 = a.slice(s![i..]);
7853
let b1 = b.slice(s![i..]);
79-
assert_approx_eq(a1.dot(&b1), reference_dot(&a1, &b1), 1e-5);
54+
assert_abs_diff_eq!(a1.dot(&b1), reference_dot(&a1, &b1), epsilon = 1e-5);
8055
let a2 = a.slice(s![..-i]);
8156
let b2 = b.slice(s![i..]);
82-
assert_approx_eq(a2.dot(&b2), reference_dot(&a2, &b2), 1e-5);
57+
assert_abs_diff_eq!(a2.dot(&b2), reference_dot(&a2, &b2), epsilon = 1e-5);
8358
}
8459

8560
let a = a.map(|f| *f as i32);
@@ -94,17 +69,17 @@ fn dot_product_0() {
9469
let x = 1.5;
9570
let b = aview0(&x);
9671
let b = b.broadcast(a.dim()).unwrap();
97-
assert_approx_eq(a.dot(&b), reference_dot(&a, &b), 1e-5);
72+
assert_abs_diff_eq!(a.dot(&b), reference_dot(&a, &b), epsilon = 1e-5);
9873

9974
// test different alignments
10075
let max = 8 as Ixs;
10176
for i in 1..max {
10277
let a1 = a.slice(s![i..]);
10378
let b1 = b.slice(s![i..]);
104-
assert_approx_eq(a1.dot(&b1), reference_dot(&a1, &b1), 1e-5);
79+
assert_abs_diff_eq!(a1.dot(&b1), reference_dot(&a1, &b1), epsilon = 1e-5);
10580
let a2 = a.slice(s![..-i]);
10681
let b2 = b.slice(s![i..]);
107-
assert_approx_eq(a2.dot(&b2), reference_dot(&a2, &b2), 1e-5);
82+
assert_abs_diff_eq!(a2.dot(&b2), reference_dot(&a2, &b2), epsilon = 1e-5);
10883
}
10984
}
11085

@@ -117,13 +92,13 @@ fn dot_product_neg_stride() {
11792
// both negative
11893
let a = a.slice(s![..;stride]);
11994
let b = b.slice(s![..;stride]);
120-
assert_approx_eq(a.dot(&b), reference_dot(&a, &b), 1e-5);
95+
assert_abs_diff_eq!(a.dot(&b), reference_dot(&a, &b), epsilon = 1e-5);
12196
}
12297
for stride in -10..0 {
12398
// mixed
12499
let a = a.slice(s![..;-stride]);
125100
let b = b.slice(s![..;stride]);
126-
assert_approx_eq(a.dot(&b), reference_dot(&a, &b), 1e-5);
101+
assert_abs_diff_eq!(a.dot(&b), reference_dot(&a, &b), epsilon = 1e-5);
127102
}
128103
}
129104

@@ -402,7 +377,7 @@ fn scaled_add_2() {
402377
answerv += &(beta * &c);
403378
av.scaled_add(beta, &c);
404379
}
405-
assert_close(a.view(), answer.view());
380+
assert_relative_eq!(a, answer, epsilon = 1e-12, max_relative = 1e-7);
406381
}
407382
}
408383
}
@@ -451,7 +426,7 @@ fn scaled_add_3() {
451426
answerv += &(beta * &c);
452427
av.scaled_add(beta, &c);
453428
}
454-
assert_close(a.view(), answer.view());
429+
assert_relative_eq!(a, answer, epsilon = 1e-12, max_relative = 1e-7);
455430
}
456431
}
457432
}
@@ -490,7 +465,7 @@ fn gen_mat_mul() {
490465

491466
general_mat_mul(alpha, &a, &b, beta, &mut cv);
492467
}
493-
assert_close(c.view(), answer.view());
468+
assert_relative_eq!(c, answer, epsilon = 1e-12, max_relative = 1e-7);
494469
}
495470
}
496471
}
@@ -507,7 +482,7 @@ fn gemm_64_1_f() {
507482
let mut y = range_mat64(m, 1);
508483
let answer = reference_mat_mul(&a, &x) + &y;
509484
general_mat_mul(1.0, &a, &x, 1.0, &mut y);
510-
assert_close(y.view(), answer.view());
485+
assert_relative_eq!(y, answer, epsilon = 1e-12, max_relative = 1e-7);
511486
}
512487

513488
#[test]
@@ -572,7 +547,7 @@ fn gen_mat_vec_mul() {
572547

573548
general_mat_vec_mul(alpha, &a, &b, beta, &mut cv);
574549
}
575-
assert_close(c.view(), answer.view());
550+
assert_relative_eq!(c, answer, epsilon = 1e-12, max_relative = 1e-7);
576551
}
577552
}
578553
}
@@ -614,7 +589,7 @@ fn vec_mat_mul() {
614589

615590
c.slice_mut(s![..;s2]).assign(&a.dot(&b));
616591
}
617-
assert_close(c.view(), answer.view());
592+
assert_relative_eq!(c, answer, epsilon = 1e-12, max_relative = 1e-7);
618593
}
619594
}
620595
}

numeric-tests/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ authors = ["bluss"]
55
publish = false
66

77
[dependencies]
8-
ndarray = { path = ".." }
8+
approx = "0.3.2"
9+
ndarray = { path = "..", features = ["approx"] }
910
ndarray-rand = { path = "../ndarray-rand/" }
1011
rand = "0.6.0"
1112

0 commit comments

Comments
 (0)