Skip to content

Commit 9758af7

Browse files
kdubovikovKirill Dubovikov
and
Kirill Dubovikov
authored
Added implementation of var and std methods for ArrayBase (#790)
Add implementation of `var` and `std` methods that are a single-dimensional versions of `var_axis` and `std_axis` methods Co-authored-by: Kirill Dubovikov <[email protected]>
1 parent 35b3ec4 commit 9758af7

File tree

3 files changed

+177
-1
lines changed

3 files changed

+177
-1
lines changed

src/numeric/impl_numeric.rs

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ use crate::imp_prelude::*;
1313
use crate::itertools::enumerate;
1414
use crate::numeric_util;
1515

16+
use crate::{FoldWhile, Zip};
17+
1618
/// # Numerical Methods for Arrays
1719
impl<A, S, D> ArrayBase<S, D>
1820
where
@@ -111,6 +113,114 @@ where
111113
sum
112114
}
113115

116+
/// Return variance of elements in the array.
117+
///
118+
/// The variance is computed using the [Welford one-pass
119+
/// algorithm](https://www.jstor.org/stable/1266577).
120+
///
121+
/// The parameter `ddof` specifies the "delta degrees of freedom". For
122+
/// example, to calculate the population variance, use `ddof = 0`, or to
123+
/// calculate the sample variance, use `ddof = 1`.
124+
///
125+
/// The variance is defined as:
126+
///
127+
/// ```text
128+
/// 1 n
129+
/// variance = ―――――――― ∑ (xᵢ - x̅)²
130+
/// n - ddof i=1
131+
/// ```
132+
///
133+
/// where
134+
///
135+
/// ```text
136+
/// 1 n
137+
/// x̅ = ― ∑ xᵢ
138+
/// n i=1
139+
/// ```
140+
///
141+
/// and `n` is the length of the array.
142+
///
143+
/// **Panics** if `ddof` is less than zero or greater than `n`
144+
///
145+
/// # Example
146+
///
147+
/// ```
148+
/// use ndarray::array;
149+
/// use approx::assert_abs_diff_eq;
150+
///
151+
/// let a = array![1., -4.32, 1.14, 0.32];
152+
/// let var = a.var(1.);
153+
/// assert_abs_diff_eq!(var, 6.7331, epsilon = 1e-4);
154+
/// ```
155+
pub fn var(&self, ddof: A) -> A
156+
where
157+
A: Float + FromPrimitive,
158+
{
159+
let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
160+
let n = A::from_usize(self.len()).expect("Converting length to `A` must not fail.");
161+
assert!(
162+
!(ddof < zero || ddof > n),
163+
"`ddof` must not be less than zero or greater than the length of \
164+
the axis",
165+
);
166+
let dof = n - ddof;
167+
let mut mean = A::zero();
168+
let mut sum_sq = A::zero();
169+
for (i, &x) in self.into_iter().enumerate() {
170+
let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
171+
let delta = x - mean;
172+
mean = mean + delta / count;
173+
sum_sq = (x - mean).mul_add(delta, sum_sq);
174+
}
175+
sum_sq / dof
176+
}
177+
178+
/// Return standard deviation of elements in the array.
179+
///
180+
/// The standard deviation is computed from the variance using
181+
/// the [Welford one-pass algorithm](https://www.jstor.org/stable/1266577).
182+
///
183+
/// The parameter `ddof` specifies the "delta degrees of freedom". For
184+
/// example, to calculate the population standard deviation, use `ddof = 0`,
185+
/// or to calculate the sample standard deviation, use `ddof = 1`.
186+
///
187+
/// The standard deviation is defined as:
188+
///
189+
/// ```text
190+
/// ⎛ 1 n ⎞
191+
/// stddev = sqrt ⎜ ―――――――― ∑ (xᵢ - x̅)²⎟
192+
/// ⎝ n - ddof i=1 ⎠
193+
/// ```
194+
///
195+
/// where
196+
///
197+
/// ```text
198+
/// 1 n
199+
/// x̅ = ― ∑ xᵢ
200+
/// n i=1
201+
/// ```
202+
///
203+
/// and `n` is the length of the array.
204+
///
205+
/// **Panics** if `ddof` is less than zero or greater than `n`
206+
///
207+
/// # Example
208+
///
209+
/// ```
210+
/// use ndarray::array;
211+
/// use approx::assert_abs_diff_eq;
212+
///
213+
/// let a = array![1., -4.32, 1.14, 0.32];
214+
/// let stddev = a.std(1.);
215+
/// assert_abs_diff_eq!(stddev, 2.59483, epsilon = 1e-4);
216+
/// ```
217+
pub fn std(&self, ddof: A) -> A
218+
where
219+
A: Float + FromPrimitive,
220+
{
221+
self.var(ddof).sqrt()
222+
}
223+
114224
/// Return sum along `axis`.
115225
///
116226
/// ```

src/private.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,5 @@ macro_rules! private_impl {
2121
fn __private__(&self) -> crate::private::PrivateMarker {
2222
crate::private::PrivateMarker
2323
}
24-
}
24+
};
2525
}

tests/numeric.rs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,72 @@ fn sum_mean_empty() {
6464
assert_eq!(a, None);
6565
}
6666

67+
#[test]
68+
fn var() {
69+
let a = array![1., -4.32, 1.14, 0.32];
70+
assert_abs_diff_eq!(a.var(0.), 5.049875, epsilon = 1e-8);
71+
}
72+
73+
#[test]
74+
#[should_panic]
75+
fn var_negative_ddof() {
76+
let a = array![1., 2., 3.];
77+
a.var(-1.);
78+
}
79+
80+
#[test]
81+
#[should_panic]
82+
fn var_too_large_ddof() {
83+
let a = array![1., 2., 3.];
84+
a.var(4.);
85+
}
86+
87+
#[test]
88+
fn var_nan_ddof() {
89+
let a = Array2::<f64>::zeros((2, 3));
90+
let v = a.var(::std::f64::NAN);
91+
assert!(v.is_nan());
92+
}
93+
94+
#[test]
95+
fn var_empty_arr() {
96+
let a: Array1<f64> = array![];
97+
assert!(a.var(0.0).is_nan());
98+
}
99+
100+
#[test]
101+
fn std() {
102+
let a = array![1., -4.32, 1.14, 0.32];
103+
assert_abs_diff_eq!(a.std(0.), 2.24719, epsilon = 1e-5);
104+
}
105+
106+
#[test]
107+
#[should_panic]
108+
fn std_negative_ddof() {
109+
let a = array![1., 2., 3.];
110+
a.std(-1.);
111+
}
112+
113+
#[test]
114+
#[should_panic]
115+
fn std_too_large_ddof() {
116+
let a = array![1., 2., 3.];
117+
a.std(4.);
118+
}
119+
120+
#[test]
121+
fn std_nan_ddof() {
122+
let a = Array2::<f64>::zeros((2, 3));
123+
let v = a.std(::std::f64::NAN);
124+
assert!(v.is_nan());
125+
}
126+
127+
#[test]
128+
fn std_empty_arr() {
129+
let a: Array1<f64> = array![];
130+
assert!(a.std(0.0).is_nan());
131+
}
132+
67133
#[test]
68134
#[cfg(feature = "approx")]
69135
fn var_axis() {

0 commit comments

Comments
 (0)