Skip to content

Commit be9a11b

Browse files
committed
Covariance-related functions for general AbstractArray
1 parent c5328b1 commit be9a11b

File tree

2 files changed

+88
-51
lines changed

2 files changed

+88
-51
lines changed

src/cov.jl

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# auxiliary functions
44

5-
function _symmetrize!(a::DenseMatrix)
5+
function _symmetrize!(a::AbstractMatrix)
66
m, n = size(a)
77
m == n || error("a must be a square matrix.")
88
for j = 1:n
@@ -15,20 +15,31 @@ function _symmetrize!(a::DenseMatrix)
1515
return a
1616
end
1717

18-
function _scalevars(x::DenseMatrix, s::AbstractWeights, dims::Int)
18+
function _symmetrize!(a::AbstractSparseMatrix)
19+
m, n = size(a)
20+
m == n || error("a must be a square matrix.")
21+
for (i,j,vl) in zip(findnz(a)...)
22+
i > j || continue
23+
vr = a[j,i]
24+
a[i,j] = a[j,i] = middle(vl, vr)
25+
end
26+
return a
27+
end
28+
29+
function _scalevars(x::AbstractMatrix, s::AbstractWeights, dims::Int)
1930
dims == 1 ? Diagonal(s) * x :
2031
dims == 2 ? x * Diagonal(s) :
2132
error("dims should be either 1 or 2.")
2233
end
2334

2435
## scatter matrix
2536

26-
_unscaled_covzm(x::DenseMatrix, dims::Colon) = unscaled_covzm(x)
27-
_unscaled_covzm(x::DenseMatrix, dims::Integer) = unscaled_covzm(x, dims)
37+
_unscaled_covzm(x::AbstractMatrix, dims::Colon) = unscaled_covzm(x)
38+
_unscaled_covzm(x::AbstractMatrix, dims::Integer) = unscaled_covzm(x, dims)
2839

29-
_unscaled_covzm(x::DenseMatrix, wv::AbstractWeights, dims::Colon) =
40+
_unscaled_covzm(x::AbstractMatrix, wv::AbstractWeights, dims::Colon) =
3041
_symmetrize!(unscaled_covzm(x, _scalevars(x, wv)))
31-
_unscaled_covzm(x::DenseMatrix, wv::AbstractWeights, dims::Integer) =
42+
_unscaled_covzm(x::AbstractMatrix, wv::AbstractWeights, dims::Integer) =
3243
_symmetrize!(unscaled_covzm(x, _scalevars(x, wv, dims), dims))
3344

3445
"""
@@ -75,30 +86,29 @@ Finally, bias correction is applied to the covariance calculation if
7586
"""
7687
function mean_and_cov end
7788

78-
scattermat(x::DenseMatrix; mean=nothing, dims::Int=1) =
89+
scattermat(x::AbstractMatrix; mean=nothing, dims::Int=1) =
7990
_scattermatm(x, mean, dims)
80-
_scattermatm(x::DenseMatrix, ::Nothing, dims::Int) =
91+
_scattermatm(x::AbstractMatrix, ::Nothing, dims::Int) =
8192
_unscaled_covzm(x .- mean(x, dims=dims), dims)
82-
_scattermatm(x::DenseMatrix, mean, dims::Int=1) =
93+
_scattermatm(x::AbstractMatrix, mean, dims::Int=1) =
8394
_unscaled_covzm(x .- mean, dims)
8495

85-
scattermat(x::DenseMatrix, wv::AbstractWeights; mean=nothing, dims::Int=1) =
96+
scattermat(x::AbstractMatrix, wv::AbstractWeights; mean=nothing, dims::Int=1) =
8697
_scattermatm(x, wv, mean, dims)
87-
_scattermatm(x::DenseMatrix, wv::AbstractWeights, ::Nothing, dims::Int) =
98+
_scattermatm(x::AbstractMatrix, wv::AbstractWeights, ::Nothing, dims::Int) =
8899
_unscaled_covzm(x .- mean(x, wv, dims=dims), wv, dims)
89-
_scattermatm(x::DenseMatrix, wv::AbstractWeights, mean, dims::Int) =
100+
_scattermatm(x::AbstractMatrix, wv::AbstractWeights, mean, dims::Int) =
90101
_unscaled_covzm(x .- mean, wv, dims)
91102

92103
## weighted cov
93-
covm(x::DenseMatrix, mean, w::AbstractWeights, dims::Int=1;
104+
covm(x::AbstractMatrix, mean, w::AbstractWeights, dims::Int=1;
94105
corrected::DepBool=nothing) =
95106
rmul!(scattermat(x, w, mean=mean, dims=dims), varcorrection(w, depcheck(:covm, corrected)))
96107

97-
98-
cov(x::DenseMatrix, w::AbstractWeights, dims::Int=1; corrected::DepBool=nothing) =
108+
cov(x::AbstractMatrix, w::AbstractWeights, dims::Int=1; corrected::DepBool=nothing) =
99109
covm(x, mean(x, w, dims=dims), w, dims; corrected=depcheck(:cov, corrected))
100110

101-
function corm(x::DenseMatrix, mean, w::AbstractWeights, vardim::Int=1)
111+
function corm(x::AbstractMatrix, mean, w::AbstractWeights, vardim::Int=1)
102112
c = covm(x, mean, w, vardim; corrected=false)
103113
s = stdm(x, w, mean, vardim; corrected=false)
104114
cov2cor!(c, s)
@@ -110,19 +120,24 @@ end
110120
Compute the Pearson correlation matrix of `X` along the dimension
111121
`dims` with a weighting `w` .
112122
"""
113-
cor(x::DenseMatrix, w::AbstractWeights, dims::Int=1) =
123+
cor(x::AbstractMatrix, w::AbstractWeights, dims::Int=1) =
114124
corm(x, mean(x, w, dims=dims), w, dims)
115125

116-
function mean_and_cov(x::DenseMatrix, dims::Int=1; corrected::Bool=true)
126+
function mean_and_cov(x::AbstractMatrix, dims::Int=1; corrected::Bool=true)
117127
m = mean(x, dims=dims)
118128
return m, covm(x, m, dims, corrected=corrected)
119129
end
120-
function mean_and_cov(x::DenseMatrix, wv::AbstractWeights, dims::Int=1;
130+
function mean_and_cov(x::AbstractMatrix, wv::AbstractWeights, dims::Int=1;
121131
corrected::DepBool=nothing)
122132
m = mean(x, wv, dims=dims)
123133
return m, cov(x, wv, dims; corrected=depcheck(:mean_and_cov, corrected))
124134
end
125135

136+
function mean_and_cov(x::AbstractVector; corrected::Bool=true)
137+
m = mean(x)
138+
return m, covm(x, m, corrected=corrected)
139+
end
140+
126141
"""
127142
cov2cor(C, s)
128143
@@ -148,8 +163,10 @@ standard deviations `s`.
148163
function cor2cov!(C::AbstractMatrix, s::AbstractArray)
149164
n = length(s)
150165
size(C) == (n, n) || throw(DimensionMismatch("inconsistent dimensions"))
151-
for i in CartesianIndices(size(C))
152-
@inbounds C[i] *= s[i[1]] * s[i[2]]
166+
@inbounds for i in CartesianIndices(size(C))
167+
si = s[i[1]] * s[i[2]]
168+
# the covariance is 0 when si==0, although C[i] is NaN in this case
169+
C[i] = iszero(si) ? zero(eltype(C)) : C[i] * si
153170
end
154171
return C
155172
end

test/cov.jl

Lines changed: 50 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,26 @@
11
using StatsBase
22
using LinearAlgebra, Random, Test
3+
using SparseArrays
34

45
struct EmptyCovarianceEstimator <: CovarianceEstimator end
56

7+
struct WrappedArray{T,N,A} <: AbstractArray{T,N}
8+
a::A
9+
WrappedArray(a::AbstractArray{T,N}) where {T,N} = new{T,N,typeof(a)}(a)
10+
end
11+
Base.size(w::WrappedArray) = size(w.a)
12+
Base.getindex(w::WrappedArray{T,N}, I::Vararg{Int, N}) where {T,N} = getindex(w.a, I...)
13+
Base.setindex!(w::WrappedArray{T,N}, v, I::Vararg{Int, N}) where {T,N} = setindex!(w.a, v, I...)
14+
15+
(x,y) = isapprox(x, y; nans=true)
16+
617
@testset "StatsBase.Covariance" begin
718
weight_funcs = (weights, aweights, fweights, pweights)
19+
array = randn(3, 8)
20+
wrapped_array = WrappedArray(array)
21+
sparse_array = sprandn(3, 8, 0.2)
822

9-
@testset "$f" for f in weight_funcs
10-
X = randn(3, 8)
23+
@testset "$f,$(typeof(X))" for f in weight_funcs, X in (array, wrapped_array, sparse_array)
1124

1225
Z1 = X .- mean(X, dims = 1)
1326
Z2 = X .- mean(X, dims = 2)
@@ -87,27 +100,32 @@ weight_funcs = (weights, aweights, fweights, pweights)
87100
@testset "Mean and covariance" begin
88101
(m, C) = mean_and_cov(X; corrected=false)
89102
@test m == mean(X, dims=1)
90-
@test C == cov(X, dims=1, corrected=false)
103+
@test C cov(X, dims=1, corrected=false)
91104

92105
(m, C) = mean_and_cov(X, 1; corrected=false)
93106
@test m == mean(X, dims=1)
94-
@test C == cov(X, dims=1, corrected = false)
107+
@test C cov(X, dims=1, corrected = false)
95108

96109
(m, C) = mean_and_cov(X, 2; corrected=false)
97110
@test m == mean(X, dims=2)
98-
@test C == cov(X, dims=2, corrected = false)
111+
@test C cov(X, dims=2, corrected = false)
99112

100113
(m, C) = mean_and_cov(X, wv1; corrected=false)
101114
@test m == mean(X, wv1, dims=1)
102-
@test C == cov(X, wv1, 1, corrected=false)
115+
@test C cov(X, wv1, 1, corrected=false)
103116

104117
(m, C) = mean_and_cov(X, wv1, 1; corrected=false)
105118
@test m == mean(X, wv1, dims=1)
106-
@test C == cov(X, wv1, 1, corrected=false)
119+
@test C cov(X, wv1, 1, corrected=false)
107120

108121
(m, C) = mean_and_cov(X, wv2, 2; corrected=false)
109122
@test m == mean(X, wv2, dims=2)
110-
@test C == cov(X, wv2, 2, corrected=false)
123+
@test C cov(X, wv2, 2, corrected=false)
124+
125+
v = collect(eachcol(X))
126+
(m, C) = mean_and_cov(v; corrected=false)
127+
@test m == mean(v)
128+
@test C cov(v, corrected=false)
111129
end
112130
@testset "Conversions" begin
113131
std1 = std(X, wv1, 1; corrected=false)
@@ -119,11 +137,12 @@ weight_funcs = (weights, aweights, fweights, pweights)
119137
cor1 = cor(X, wv1, 1)
120138
cor2 = cor(X, wv2, 2)
121139

140+
122141
@testset "cov2cor" begin
123-
@test cov2cor(cov(X, dims = 1), std(X, dims = 1)) cor(X, dims = 1)
124-
@test cov2cor(cov(X, dims = 2), std(X, dims = 2)) cor(X, dims = 2)
125-
@test cov2cor(cov1, std1) cor1
126-
@test cov2cor(cov2, std2) cor2
142+
@test cov2cor(cov(X, dims = 1), std(X, dims = 1)) cor(X, dims = 1)
143+
@test cov2cor(cov(X, dims = 2), std(X, dims = 2)) cor(X, dims = 2)
144+
@test cov2cor(cov1, std1) cor1
145+
@test cov2cor(cov2, std2) cor2
127146
end
128147
@testset "cor2cov" begin
129148
@test cor2cov(cor(X, dims = 1), std(X, dims = 1)) cov(X, dims = 1)
@@ -158,30 +177,30 @@ weight_funcs = (weights, aweights, fweights, pweights)
158177
@testset "Mean and covariance" begin
159178
(m, C) = mean_and_cov(X; corrected=true)
160179
@test m == mean(X, dims=1)
161-
@test C == cov(X, dims=1, corrected = true)
180+
@test C cov(X, dims=1, corrected = true)
162181

163182
(m, C) = mean_and_cov(X, 1; corrected=true)
164183
@test m == mean(X, dims=1)
165-
@test C == cov(X, dims=1, corrected = true)
184+
@test C cov(X, dims=1, corrected = true)
166185

167186
(m, C) = mean_and_cov(X, 2; corrected=true)
168187
@test m == mean(X, dims=2)
169-
@test C == cov(X, dims=2, corrected = true)
188+
@test C cov(X, dims=2, corrected = true)
170189

171190
if isa(wv1, Weights)
172191
@test_throws ArgumentError mean_and_cov(X, wv1; corrected=true)
173192
else
174193
(m, C) = mean_and_cov(X, wv1; corrected=true)
175194
@test m == mean(X, wv1, dims=1)
176-
@test C == cov(X, wv1, 1; corrected=true)
195+
@test C cov(X, wv1, 1; corrected=true)
177196

178197
(m, C) = mean_and_cov(X, wv1, 1; corrected=true)
179198
@test m == mean(X, wv1, dims=1)
180-
@test C == cov(X, wv1, 1; corrected=true)
199+
@test C cov(X, wv1, 1; corrected=true)
181200

182201
(m, C) = mean_and_cov(X, wv2, 2; corrected=true)
183202
@test m == mean(X, wv2, dims=2)
184-
@test C == cov(X, wv2, 2; corrected=true)
203+
@test C cov(X, wv2, 2; corrected=true)
185204
end
186205
end
187206
@testset "Conversions" begin
@@ -196,25 +215,26 @@ weight_funcs = (weights, aweights, fweights, pweights)
196215
cor2 = cor(X, wv2, 2)
197216

198217
@testset "cov2cor" begin
199-
@test cov2cor(cov(X, dims = 1), std(X, dims = 1)) cor(X, dims = 1)
200-
@test cov2cor(cov(X, dims = 2), std(X, dims = 2)) cor(X, dims = 2)
201-
@test cov2cor(cov1, std1) cor1
202-
@test cov2cor(cov2, std2) cor2
218+
@test cov2cor(cov(X, dims = 1), std(X, dims = 1)) cor(X, dims = 1)
219+
@test cov2cor(cov(X, dims = 2), std(X, dims = 2)) cor(X, dims = 2)
220+
@test cov2cor(cov1, std1) cor1
221+
@test cov2cor(cov2, std2) cor2
203222
end
204223

205224
@testset "cov2cor!" begin
206225
tmp_cov1 = copy(cov1)
207-
@test !(tmp_cov1 cor1)
226+
@test !(tmp_cov1 cor1)
208227
StatsBase.cov2cor!(tmp_cov1, std1)
209-
@test tmp_cov1 cor1
228+
@test tmp_cov1 cor1
210229

211230
tmp_cov2 = copy(cov2)
212-
@test !(tmp_cov2 cor2)
231+
@test !(tmp_cov2 cor2)
213232
StatsBase.cov2cor!(tmp_cov2, std2)
214-
@test tmp_cov2 cor2
233+
@test tmp_cov2 cor2
215234
end
216235

217236
@testset "cor2cov" begin
237+
218238
@test cor2cov(cor(X, dims = 1), std(X, dims = 1)) cov(X, dims = 1)
219239
@test cor2cov(cor(X, dims = 2), std(X, dims = 2)) cov(X, dims = 2)
220240
@test cor2cov(cor1, std1) cov1
@@ -237,8 +257,8 @@ weight_funcs = (weights, aweights, fweights, pweights)
237257
end
238258

239259
@testset "Correlation" begin
240-
@test cor(X, f(ones(3)), 1) cor(X, dims = 1)
241-
@test cor(X, f(ones(8)), 2) cor(X, dims = 2)
260+
@test cor(X, f(ones(3)), 1) cor(X, dims = 1)
261+
@test cor(X, f(ones(8)), 2) cor(X, dims = 2)
242262

243263
cov1 = cov(X, wv1, 1; corrected=false)
244264
std1 = std(X, wv1, 1; corrected=false)
@@ -247,8 +267,8 @@ weight_funcs = (weights, aweights, fweights, pweights)
247267
expected_cor1 = StatsBase.cov2cor!(cov1, std1)
248268
expected_cor2 = StatsBase.cov2cor!(cov2, std2)
249269

250-
@test cor(X, wv1, 1) expected_cor1
251-
@test cor(X, wv2, 2) expected_cor2
270+
@test cor(X, wv1, 1) expected_cor1
271+
@test cor(X, wv2, 2) expected_cor2
252272
end
253273

254274
@testset "Abstract covariance estimation" begin

0 commit comments

Comments
 (0)