Skip to content

Commit f3f7b18

Browse files
make SparseArrays a weak dependency
1 parent e9ac70b commit f3f7b18

File tree

3 files changed

+113
-90
lines changed

3 files changed

+113
-90
lines changed

Project.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,16 @@ version = "1.9.0"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1111

12+
[weakdeps]
13+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
14+
15+
[extensions]
16+
SparseArraysExt = ["SparseArrays"]
17+
1218
[extras]
1319
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
20+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1421
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1522

1623
[targets]
17-
test = ["Random", "Test"]
24+
test = ["Random", "SparseArrays", "Test"]

ext/SparseArraysExt.jl

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
module SparseArraysExt
2+
3+
##### SparseArrays optimizations #####
4+
5+
using Base: require_one_based_indexing
6+
using LinearAlgebra
7+
using SparseArrays
8+
using Statistics
9+
using Statistics: centralize_sumabs2, unscaled_covzm
10+
11+
# extended functions
12+
import Statistics: cov, centralize_sumabs2!
13+
14+
function cov(X::SparseMatrixCSC; dims::Int=1, corrected::Bool=true)
15+
vardim = dims
16+
a, b = size(X)
17+
n, p = vardim == 1 ? (a, b) : (b, a)
18+
19+
# The covariance can be decomposed into two terms
20+
# 1/(n - 1) ∑ (x_i - x̄)*(x_i - x̄)' = 1/(n - 1) (∑ x_i*x_i' - n*x̄*x̄')
21+
# which can be evaluated via a sparse matrix-matrix product
22+
23+
# Compute ∑ x_i*x_i' = X'X using sparse matrix-matrix product
24+
out = Matrix(unscaled_covzm(X, vardim))
25+
26+
# Compute x̄
27+
x̄ᵀ = mean(X, dims=vardim)
28+
29+
# Subtract n*x̄*x̄' from X'X
30+
@inbounds for j in 1:p, i in 1:p
31+
out[i,j] -= x̄ᵀ[i] * x̄ᵀ[j]' * n
32+
end
33+
34+
# scale with the sample size n or the corrected sample size n - 1
35+
return rmul!(out, inv(n - corrected))
36+
end
37+
38+
# This is the function that does the reduction underlying var/std
39+
function centralize_sumabs2!(R::AbstractArray{S}, A::SparseMatrixCSC{Tv,Ti}, means::AbstractArray) where {S,Tv,Ti}
40+
require_one_based_indexing(R, A, means)
41+
lsiz = Base.check_reducedims(R,A)
42+
for i in 1:max(ndims(R), ndims(means))
43+
if axes(means, i) != axes(R, i)
44+
throw(DimensionMismatch("dimension $i of `mean` should have indices $(axes(R, i)), but got $(axes(means, i))"))
45+
end
46+
end
47+
isempty(R) || fill!(R, zero(S))
48+
isempty(A) && return R
49+
50+
rowval = rowvals(A)
51+
nzval = nonzeros(A)
52+
m = size(A, 1)
53+
n = size(A, 2)
54+
55+
if size(R, 1) == size(R, 2) == 1
56+
# Reduction along both columns and rows
57+
R[1, 1] = centralize_sumabs2(A, means[1])
58+
elseif size(R, 1) == 1
59+
# Reduction along rows
60+
@inbounds for col = 1:n
61+
mu = means[col]
62+
r = convert(S, (m - length(nzrange(A, col)))*abs2(mu))
63+
@simd for j = nzrange(A, col)
64+
r += abs2(nzval[j] - mu)
65+
end
66+
R[1, col] = r
67+
end
68+
elseif size(R, 2) == 1
69+
# Reduction along columns
70+
rownz = fill(convert(Ti, n), m)
71+
@inbounds for col = 1:n
72+
@simd for j = nzrange(A, col)
73+
row = rowval[j]
74+
R[row, 1] += abs2(nzval[j] - means[row])
75+
rownz[row] -= 1
76+
end
77+
end
78+
for i = 1:m
79+
R[i, 1] += rownz[i]*abs2(means[i])
80+
end
81+
else
82+
# Reduction along a dimension > 2
83+
@inbounds for col = 1:n
84+
lastrow = 0
85+
@simd for j = nzrange(A, col)
86+
row = rowval[j]
87+
for i = lastrow+1:row-1
88+
R[i, col] = abs2(means[i, col])
89+
end
90+
R[row, col] = abs2(nzval[j] - means[row, col])
91+
lastrow = row
92+
end
93+
for i = lastrow+1:m
94+
R[i, col] = abs2(means[i, col])
95+
end
96+
end
97+
end
98+
return R
99+
end
100+
101+
end # module

src/Statistics.jl

Lines changed: 4 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Standard library module for basic statistics functionality.
77
"""
88
module Statistics
99

10-
using LinearAlgebra, SparseArrays
10+
using LinearAlgebra
1111

1212
using Base: has_offset_axes, require_one_based_indexing
1313

@@ -1073,94 +1073,9 @@ quantile(itr, p; sorted::Bool=false, alpha::Real=1.0, beta::Real=alpha) =
10731073
quantile(v::AbstractVector, p; sorted::Bool=false, alpha::Real=1.0, beta::Real=alpha) =
10741074
quantile!(sorted ? v : Base.copymutable(v), p; sorted=sorted, alpha=alpha, beta=beta)
10751075

1076-
1077-
##### SparseArrays optimizations #####
1078-
1079-
function cov(X::SparseMatrixCSC; dims::Int=1, corrected::Bool=true)
1080-
vardim = dims
1081-
a, b = size(X)
1082-
n, p = vardim == 1 ? (a, b) : (b, a)
1083-
1084-
# The covariance can be decomposed into two terms
1085-
# 1/(n - 1) ∑ (x_i - x̄)*(x_i - x̄)' = 1/(n - 1) (∑ x_i*x_i' - n*x̄*x̄')
1086-
# which can be evaluated via a sparse matrix-matrix product
1087-
1088-
# Compute ∑ x_i*x_i' = X'X using sparse matrix-matrix product
1089-
out = Matrix(unscaled_covzm(X, vardim))
1090-
1091-
# Compute x̄
1092-
x̄ᵀ = mean(X, dims=vardim)
1093-
1094-
# Subtract n*x̄*x̄' from X'X
1095-
@inbounds for j in 1:p, i in 1:p
1096-
out[i,j] -= x̄ᵀ[i] * x̄ᵀ[j]' * n
1097-
end
1098-
1099-
# scale with the sample size n or the corrected sample size n - 1
1100-
return rmul!(out, inv(n - corrected))
1101-
end
1102-
1103-
# This is the function that does the reduction underlying var/std
1104-
function centralize_sumabs2!(R::AbstractArray{S}, A::SparseMatrixCSC{Tv,Ti}, means::AbstractArray) where {S,Tv,Ti}
1105-
require_one_based_indexing(R, A, means)
1106-
lsiz = Base.check_reducedims(R,A)
1107-
for i in 1:max(ndims(R), ndims(means))
1108-
if axes(means, i) != axes(R, i)
1109-
throw(DimensionMismatch("dimension $i of `mean` should have indices $(axes(R, i)), but got $(axes(means, i))"))
1110-
end
1111-
end
1112-
isempty(R) || fill!(R, zero(S))
1113-
isempty(A) && return R
1114-
1115-
rowval = rowvals(A)
1116-
nzval = nonzeros(A)
1117-
m = size(A, 1)
1118-
n = size(A, 2)
1119-
1120-
if size(R, 1) == size(R, 2) == 1
1121-
# Reduction along both columns and rows
1122-
R[1, 1] = centralize_sumabs2(A, means[1])
1123-
elseif size(R, 1) == 1
1124-
# Reduction along rows
1125-
@inbounds for col = 1:n
1126-
mu = means[col]
1127-
r = convert(S, (m - length(nzrange(A, col)))*abs2(mu))
1128-
@simd for j = nzrange(A, col)
1129-
r += abs2(nzval[j] - mu)
1130-
end
1131-
R[1, col] = r
1132-
end
1133-
elseif size(R, 2) == 1
1134-
# Reduction along columns
1135-
rownz = fill(convert(Ti, n), m)
1136-
@inbounds for col = 1:n
1137-
@simd for j = nzrange(A, col)
1138-
row = rowval[j]
1139-
R[row, 1] += abs2(nzval[j] - means[row])
1140-
rownz[row] -= 1
1141-
end
1142-
end
1143-
for i = 1:m
1144-
R[i, 1] += rownz[i]*abs2(means[i])
1145-
end
1146-
else
1147-
# Reduction along a dimension > 2
1148-
@inbounds for col = 1:n
1149-
lastrow = 0
1150-
@simd for j = nzrange(A, col)
1151-
row = rowval[j]
1152-
for i = lastrow+1:row-1
1153-
R[i, col] = abs2(means[i, col])
1154-
end
1155-
R[row, col] = abs2(nzval[j] - means[row, col])
1156-
lastrow = row
1157-
end
1158-
for i = lastrow+1:m
1159-
R[i, col] = abs2(means[i, col])
1160-
end
1161-
end
1162-
end
1163-
return R
1076+
## If package extensions are not supported in this Julia version
1077+
if !isdefined(Base, :get_extension)
1078+
include("../ext/SparseArraysExt.jl")
11641079
end
11651080

11661081
end # module

0 commit comments

Comments
 (0)