Skip to content

Commit f31afb5

Browse files
committed
WIP: implement reductions as reductions
Depends upon (and is a requirement for) Julia issue 55318
1 parent 128dc11 commit f31afb5

File tree

2 files changed

+113
-85
lines changed

2 files changed

+113
-85
lines changed

src/Statistics.jl

Lines changed: 79 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,18 @@ julia> mean(skipmissing([1, missing, 3]))
4343
"""
4444
mean(itr) = mean(identity, itr)
4545

46+
struct Counter{F} <: Function
47+
f::F
48+
n::Base.RefValue{Int}
49+
end
50+
Counter(f::F) where {F} = Counter{F}(f, Ref(0))
51+
(f::Counter)(x) = (f.n[] += 1; f.f(x))
52+
53+
struct DivOne{F} <: Function
54+
f::F
55+
end
56+
(f::DivOne)(x) = f.f(x)/1
57+
4658
"""
4759
mean(f, itr)
4860
@@ -59,23 +71,13 @@ julia> mean([√1, √2, √3])
5971
```
6072
"""
6173
function mean(f, itr)
62-
y = iterate(itr)
63-
if y === nothing
64-
return Base.mapreduce_empty_iter(f, +, itr,
65-
Base.IteratorEltype(itr)) / 0
66-
end
67-
count = 1
68-
value, state = y
69-
f_value = f(value)/1
70-
total = Base.reduce_first(+, f_value)
71-
y = iterate(itr, state)
72-
while y !== nothing
73-
value, state = y
74-
total += _mean_promote(total, f(value))
75-
count += 1
76-
y = iterate(itr, state)
74+
if Base.IteratorSize(itr) === Base.SizeUnknown()
75+
g = Counter(DivOne(f))
76+
result = mapfoldl(g, add_mean, itr)
77+
return result/g.n[]
78+
else
79+
return mapfoldl(DivOne(f), add_mean, itr)/length(itr)
7780
end
78-
return total/count
7981
end
8082

8183
"""
@@ -180,20 +182,24 @@ mean(A::AbstractArray; dims=:) = _mean(identity, A, dims)
180182

181183
_mean_promote(x::T, y::S) where {T,S} = convert(promote_type(T, S), y)
182184

185+
add_mean(x, y) = Base.add_sum(x, _mean_promote(x, y))
186+
187+
Base.reduce_empty(::typeof(add_mean), T) = Base.reduce_empty(Base.add_sum, T)
188+
Base.mapreduce_empty(g::DivOne, ::typeof(add_mean), T) = Base.mapreduce_empty(g.f, Base.add_sum, T)/1
189+
Base.mapreduce_empty(g::Counter{<:DivOne}, ::typeof(add_mean), T) = Base.mapreduce_empty(g.f.f, Base.add_sum, T)/1
190+
191+
183192
# ::Dims is there to force specializing on Colon (as it is a Function)
184193
function _mean(f, A::AbstractArray, dims::Dims=:) where Dims
185-
isempty(A) && return sum(f, A, dims=dims)/0
186194
if dims === (:)
195+
result = mapreduce(DivOne(f), add_mean, A, dims=dims)
187196
n = length(A)
188-
else
189-
n = mapreduce(i -> size(A, i), *, unique(dims); init=1)
190-
end
191-
x1 = f(first(A)) / 1
192-
result = sum(x -> _mean_promote(x1, f(x)), A, dims=dims)
193-
if dims === (:)
194197
return result / n
195198
else
196-
return result ./= n
199+
result = mapreduce(DivOne(f), add_mean, A, dims=dims)
200+
n = prod(i -> size(A, i), unique(dims); init=1)
201+
result ./= n
202+
return result
197203
end
198204
end
199205

@@ -211,6 +217,7 @@ realXcY(x::Complex, y::Complex) = real(x)*real(y) + imag(x)*imag(y)
211217
var(iterable; corrected::Bool=true, mean=nothing) = _var(iterable, corrected, mean)
212218

213219
function _var(iterable, corrected::Bool, mean)
220+
ismissing(mean) && return missing
214221
y = iterate(iterable)
215222
if y === nothing
216223
T = eltype(iterable)
@@ -252,61 +259,36 @@ function _var(iterable, corrected::Bool, mean)
252259
end
253260
end
254261

255-
centralizedabs2fun(m) = x -> abs2.(x - m)
256-
centralize_sumabs2(A::AbstractArray, m) =
257-
mapreduce(centralizedabs2fun(m), +, A)
258-
centralize_sumabs2(A::AbstractArray, m, ifirst::Int, ilast::Int) =
259-
Base.mapreduce_impl(centralizedabs2fun(m), +, A, ifirst, ilast)
260-
261-
function centralize_sumabs2!(R::AbstractArray{S}, A::AbstractArray, means::AbstractArray) where S
262-
# following the implementation of _mapreducedim! at base/reducedim.jl
263-
lsiz = Base.check_reducedims(R,A)
264-
for i in 1:max(ndims(R), ndims(means))
265-
if axes(means, i) != axes(R, i)
266-
throw(DimensionMismatch("dimension $i of `mean` should have indices $(axes(R, i)), but got $(axes(means, i))"))
267-
end
268-
end
269-
isempty(R) || fill!(R, zero(S))
270-
isempty(A) && return R
271-
272-
if Base.has_fast_linear_indexing(A) && lsiz > 16 && !has_offset_axes(R, means)
273-
nslices = div(length(A), lsiz)
274-
ibase = first(LinearIndices(A))-1
275-
for i = 1:nslices
276-
@inbounds R[i] = centralize_sumabs2(A, means[i], ibase+1, ibase+lsiz)
277-
ibase += lsiz
278-
end
279-
return R
280-
end
281-
indsAt, indsRt = Base.safe_tail(axes(A)), Base.safe_tail(axes(R)) # handle d=1 manually
282-
keep, Idefault = Broadcast.shapeindexer(indsRt)
283-
if Base.reducedim1(R, A)
284-
i1 = first(Base.axes1(R))
285-
@inbounds for IA in CartesianIndices(indsAt)
286-
IR = Broadcast.newindex(IA, keep, Idefault)
287-
r = R[i1,IR]
288-
m = means[i1,IR]
289-
@simd for i in axes(A, 1)
290-
r += abs2(A[i,IA] - m)
291-
end
292-
R[i1,IR] = r
293-
end
294-
else
295-
@inbounds for IA in CartesianIndices(indsAt)
296-
IR = Broadcast.newindex(IA, keep, Idefault)
297-
@simd for i in axes(A, 1)
298-
R[i,IR] += abs2(A[i,IA] - means[i,IR])
299-
end
300-
end
301-
end
302-
return R
262+
struct CentralizedAbs2Fun{T,S} <: Function
263+
mean::S
303264
end
265+
CentralizedAbs2Fun{T}(means) where {T} = CentralizedAbs2Fun{T,typeof(means)}(means)
266+
CentralizedAbs2Fun(means) = CentralizedAbs2Fun{typeof(means)}(means)
267+
CentralizedAbs2Fun(means, extrude) = CentralizedAbs2Fun{eltype(means)}(Broadcast.extrude(means))
268+
# Division is generally costly, but Julia is typically able to constant propagate a /1
269+
# and simply ensure we get the type right at no cost, allowing the division in-place later
270+
(f::CentralizedAbs2Fun)(x) = abs2.(x - f.mean)/1
271+
(f::CentralizedAbs2Fun{<:Any,<:Broadcast.Extruded})((i, x),) = abs2.(x - Broadcast._broadcast_getindex(f.mean, i))/1
272+
_doubled(x) = x+x
273+
Base.mapreduce_empty(::CentralizedAbs2Fun{T,<:Broadcast.Extruded}, ::typeof(Base.add_sum), ::Type{Tuple{_Any,S}}) where {T<:Number, S<:Number, _Any} = _doubled(abs2(zero(T)-zero(S)))/1
274+
Base.mapreduce_empty(::CentralizedAbs2Fun{T,<:Broadcast.Extruded}, ::typeof(Base.add_sum), ::Type{Tuple{_Any, Union{Missing, S}}}) where {T<:Number, S<:Number, _Any} = _doubled(abs2(zero(T)-zero(S)))/1
275+
Base.mapreduce_empty(::CentralizedAbs2Fun{T}, ::typeof(Base.add_sum), ::Type{S}) where {T<:Number, S<:Number} = _doubled(abs2(zero(T)-zero(S)))/1
276+
Base.mapreduce_empty(::CentralizedAbs2Fun{T}, ::typeof(Base.add_sum), ::Type{Union{Missing, S}}) where {T<:Number, S<:Number} = _doubled(abs2(zero(T)-zero(S)))/1
277+
278+
centralize_sumabs2(A::AbstractArray, m) =
279+
sum(CentralizedAbs2Fun(m), A)
280+
centralize_sumabs2(A::AbstractArray, m::AbstractArray, region) =
281+
sum(CentralizedAbs2Fun(m, true), Base.PairsArray(A), dims=region)
282+
centralize_sumabs2!(R::AbstractArray, A::AbstractArray, means::AbstractArray) =
283+
sum!(CentralizedAbs2Fun(means, true), R, Base.PairsArray(A))
284+
304285

305286
function varm!(R::AbstractArray{S}, A::AbstractArray, m::AbstractArray; corrected::Bool=true) where S
306-
if isempty(A)
287+
_checkm(R, m, ntuple(identity, Val(max(ndims(R), ndims(m)))))
288+
if isempty(A) || length(A) == 1 && corrected
307289
fill!(R, convert(S, NaN))
308290
else
309-
rn = div(length(A), length(R)) - Int(corrected)
291+
rn = prod(ntuple(d->size(R, d) == 1 ? size(A, d) : 1, Val(max(ndims(A), ndims(R))))) - Int(corrected)
310292
centralize_sumabs2!(R, A, m)
311293
R .= R .* (1 // rn)
312294
end
@@ -339,15 +321,33 @@ over dimensions. In that case, `mean` must be an array with the same shape as
339321
"""
340322
varm(A::AbstractArray, m::AbstractArray; corrected::Bool=true, dims=:) = _varm(A, m, corrected, dims)
341323

342-
_varm(A::AbstractArray{T}, m, corrected::Bool, region) where {T} =
343-
varm!(Base.reducedim_init(t -> abs2(t)/2, +, A, region), A, m; corrected=corrected)
324+
_throw_mean_mismatch(A, m, region) = throw(DimensionMismatch("axes of means ($(axes(m))) does not match reduction over $(region) of $(axes(A))"))
325+
function _checkm(A::AbstractArray, m::AbstractArray, region)
326+
for d in 1:max(ndims(A), ndims(m))
327+
if d in region
328+
size(m, d) == 1 || _throw_mean_mismatch(A, m, region)
329+
else
330+
axes(m, d) == axes(A, d) || _throw_mean_mismatch(A, m, region)
331+
end
332+
end
333+
end
334+
function _varm(A::AbstractArray, m, corrected::Bool, region)
335+
_checkm(A, m, region)
336+
rn = prod(ntuple(d->d in region ? size(A, d) : 1, Val(ndims(A)))) - Int(corrected)
337+
R = centralize_sumabs2(A, m, region)
338+
if rn <= 0
339+
R .= R ./ 0
340+
else
341+
R .= R .* 1//rn # why use Rational?
342+
end
343+
return R
344+
end
344345

345346
varm(A::AbstractArray, m; corrected::Bool=true) = _varm(A, m, corrected, :)
346347

347348
function _varm(A::AbstractArray{T}, m, corrected::Bool, ::Colon) where T
348-
n = length(A)
349-
n == 0 && return oftype((abs2(zero(T)) + abs2(zero(T)))/2, NaN)
350-
return centralize_sumabs2(A, m) / (n - Int(corrected))
349+
rn = max(length(A) - Int(corrected), 0)
350+
centralize_sumabs2(A, m)/rn
351351
end
352352

353353

test/runtests.jl

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -498,9 +498,15 @@ Y = [6.0 2.0;
498498
@testset "cov with missing" begin
499499
@test cov([missing]) === cov([1, missing]) === missing
500500
@test cov([1, missing], [2, 3]) === cov([1, 3], [2, missing]) === missing
501-
@test_throws Exception cov([1 missing; 2 3])
502-
@test_throws Exception cov([1 missing; 2 3], [1, 2])
503-
@test_throws Exception cov([1, 2], [1 missing; 2 3])
501+
if isdefined(Base, :_reducedim_init)
502+
@test_broken isequal(coalesce.(cov([1 missing; 2 3]), NaN), cov([1 NaN; 2 3]))
503+
@test_broken isequal(coalesce.(cov([1 missing; 2 3], [1, 2]), NaN), cov([1 NaN; 2 3], [1, 2]))
504+
@test_broken isequal(coalesce.(cov([1, 2], [1 missing; 2 3]), NaN), cov([1, 2], [1 NaN; 2 3]))
505+
else
506+
@test isequal(coalesce.(cov([1 missing; 2 3]), NaN), cov([1 NaN; 2 3]))
507+
@test isequal(coalesce.(cov([1 missing; 2 3], [1, 2]), NaN), cov([1 NaN; 2 3], [1, 2]))
508+
@test isequal(coalesce.(cov([1, 2], [1 missing; 2 3]), NaN), cov([1, 2], [1 NaN; 2 3]))
509+
end
504510
@test isequal(cov([1 2; 2 3], [1, missing]), [missing missing]')
505511
@test isequal(cov([1, missing], [1 2; 2 3]), [missing missing])
506512
end
@@ -609,9 +615,15 @@ end
609615
@test cor([missing]) === missing
610616
@test cor([1, missing]) == 1
611617
@test cor([1, missing], [2, 3]) === cor([1, 3], [2, missing]) === missing
612-
@test_throws Exception cor([1 missing; 2 3])
613-
@test_throws Exception cor([1 missing; 2 3], [1, 2])
614-
@test_throws Exception cor([1, 2], [1 missing; 2 3])
618+
if isdefined(Base, :_reducedim_init)
619+
@test_broken isequal(coalesce.(cor([1 missing; 2 3]), NaN), cor([1 NaN; 2 3]))
620+
@test_broken isequal(coalesce.(cor([1 missing; 2 3], [1, 2]), NaN), cor([1 NaN; 2 3], [1, 2]))
621+
@test_broken isequal(coalesce.(cor([1, 2], [1 missing; 2 3]), NaN), cor([1, 2], [1 NaN; 2 3]))
622+
else
623+
@test isequal(coalesce.(cor([1 missing; 2 3]), NaN), cor([1 NaN; 2 3]))
624+
@test isequal(coalesce.(cor([1 missing; 2 3], [1, 2]), NaN), cor([1 NaN; 2 3], [1, 2]))
625+
@test isequal(coalesce.(cor([1, 2], [1 missing; 2 3]), NaN), cor([1, 2], [1 NaN; 2 3]))
626+
end
615627
@test isequal(cor([1 2; 2 3], [1, missing]), [missing missing]')
616628
@test isequal(cor([1, missing], [1 2; 2 3]), [missing missing])
617629
end
@@ -1030,3 +1042,19 @@ end
10301042
@test isequal(cov(Int[], my), fill(-0.0, 1, 3))
10311043
@test isequal(cor(Int[], my), fill(NaN, 1, 3))
10321044
end
1045+
1046+
@testset "mean, var, std type stability with Missings; Issue #160" begin
1047+
@test (@inferred Missing mean(view([1, 2, missing], 1:2))) == (@inferred mean([1,2]))
1048+
@test (@inferred Missing var(view([1, 2, missing], 1:2))) == (@inferred var([1,2]))
1049+
@test (@inferred Missing std(view([1, 2, missing], 1:2))) == (@inferred std([1,2]))
1050+
end
1051+
1052+
@testset "inexact errors; Issues #7 and #126" begin
1053+
a = [missing missing; 0 1]
1054+
@test isequal(mean(a;dims=2), [missing; 0.5;;])
1055+
1056+
x = [(i==3 && j==3) ? missing : i*j for i in 1:3, j in 1:4]
1057+
@test ismissing(@inferred Float64 mean(x))
1058+
@test isequal(mean(x; dims=1), [2. 4. missing 8.])
1059+
@test isequal(mean(x; dims=2), [2.5; 5.0; missing;;])
1060+
end

0 commit comments

Comments
 (0)