@@ -43,6 +43,18 @@ julia> mean(skipmissing([1, missing, 3]))
43
43
"""
44
44
mean (itr) = mean (identity, itr)
45
45
46
+ struct Counter{F} <: Function
47
+ f:: F
48
+ n:: Base.RefValue{Int}
49
+ end
50
+ Counter (f:: F ) where {F} = Counter {F} (f, Ref (0 ))
51
+ (f:: Counter )(x) = (f. n[] += 1 ; f. f (x))
52
+
53
+ struct DivOne{F} <: Function
54
+ f:: F
55
+ end
56
+ (f:: DivOne )(x) = f. f (x)/ 1
57
+
46
58
"""
47
59
mean(f, itr)
48
60
@@ -59,23 +71,13 @@ julia> mean([√1, √2, √3])
59
71
```
60
72
"""
61
73
function mean (f, itr)
62
- y = iterate (itr)
63
- if y === nothing
64
- return Base. mapreduce_empty_iter (f, + , itr,
65
- Base. IteratorEltype (itr)) / 0
66
- end
67
- count = 1
68
- value, state = y
69
- f_value = f (value)/ 1
70
- total = Base. reduce_first (+ , f_value)
71
- y = iterate (itr, state)
72
- while y != = nothing
73
- value, state = y
74
- total += _mean_promote (total, f (value))
75
- count += 1
76
- y = iterate (itr, state)
74
+ if Base. IteratorSize (itr) === Base. SizeUnknown ()
75
+ g = Counter (DivOne (f))
76
+ result = mapfoldl (g, add_mean, itr)
77
+ return result/ g. n[]
78
+ else
79
+ return mapfoldl (DivOne (f), add_mean, itr)/ length (itr)
77
80
end
78
- return total/ count
79
81
end
80
82
81
83
"""
@@ -180,20 +182,24 @@ mean(A::AbstractArray; dims=:) = _mean(identity, A, dims)
180
182
181
183
_mean_promote (x:: T , y:: S ) where {T,S} = convert (promote_type (T, S), y)
182
184
185
+ add_mean (x, y) = Base. add_sum (x, _mean_promote (x, y))
186
+
187
+ Base. reduce_empty (:: typeof (add_mean), T) = Base. reduce_empty (Base. add_sum, T)
188
+ Base. mapreduce_empty (g:: DivOne , :: typeof (add_mean), T) = Base. mapreduce_empty (g. f, Base. add_sum, T)/ 1
189
+ Base. mapreduce_empty (g:: Counter{<:DivOne} , :: typeof (add_mean), T) = Base. mapreduce_empty (g. f. f, Base. add_sum, T)/ 1
190
+
191
+
183
192
# ::Dims is there to force specializing on Colon (as it is a Function)
184
193
function _mean (f, A:: AbstractArray , dims:: Dims = :) where Dims
185
- isempty (A) && return sum (f, A, dims= dims)/ 0
186
194
if dims === (:)
195
+ result = mapreduce (DivOne (f), add_mean, A, dims= dims)
187
196
n = length (A)
188
- else
189
- n = mapreduce (i -> size (A, i), * , unique (dims); init= 1 )
190
- end
191
- x1 = f (first (A)) / 1
192
- result = sum (x -> _mean_promote (x1, f (x)), A, dims= dims)
193
- if dims === (:)
194
197
return result / n
195
198
else
196
- return result ./= n
199
+ result = mapreduce (DivOne (f), add_mean, A, dims= dims)
200
+ n = prod (i -> size (A, i), unique (dims); init= 1 )
201
+ result ./= n
202
+ return result
197
203
end
198
204
end
199
205
@@ -211,6 +217,7 @@ realXcY(x::Complex, y::Complex) = real(x)*real(y) + imag(x)*imag(y)
211
217
var (iterable; corrected:: Bool = true , mean= nothing ) = _var (iterable, corrected, mean)
212
218
213
219
function _var (iterable, corrected:: Bool , mean)
220
+ ismissing (mean) && return missing
214
221
y = iterate (iterable)
215
222
if y === nothing
216
223
T = eltype (iterable)
@@ -252,61 +259,36 @@ function _var(iterable, corrected::Bool, mean)
252
259
end
253
260
end
254
261
255
- centralizedabs2fun (m) = x -> abs2 .(x - m)
256
- centralize_sumabs2 (A:: AbstractArray , m) =
257
- mapreduce (centralizedabs2fun (m), + , A)
258
- centralize_sumabs2 (A:: AbstractArray , m, ifirst:: Int , ilast:: Int ) =
259
- Base. mapreduce_impl (centralizedabs2fun (m), + , A, ifirst, ilast)
260
-
261
- function centralize_sumabs2! (R:: AbstractArray{S} , A:: AbstractArray , means:: AbstractArray ) where S
262
- # following the implementation of _mapreducedim! at base/reducedim.jl
263
- lsiz = Base. check_reducedims (R,A)
264
- for i in 1 : max (ndims (R), ndims (means))
265
- if axes (means, i) != axes (R, i)
266
- throw (DimensionMismatch (" dimension $i of `mean` should have indices $(axes (R, i)) , but got $(axes (means, i)) " ))
267
- end
268
- end
269
- isempty (R) || fill! (R, zero (S))
270
- isempty (A) && return R
271
-
272
- if Base. has_fast_linear_indexing (A) && lsiz > 16 && ! has_offset_axes (R, means)
273
- nslices = div (length (A), lsiz)
274
- ibase = first (LinearIndices (A))- 1
275
- for i = 1 : nslices
276
- @inbounds R[i] = centralize_sumabs2 (A, means[i], ibase+ 1 , ibase+ lsiz)
277
- ibase += lsiz
278
- end
279
- return R
280
- end
281
- indsAt, indsRt = Base. safe_tail (axes (A)), Base. safe_tail (axes (R)) # handle d=1 manually
282
- keep, Idefault = Broadcast. shapeindexer (indsRt)
283
- if Base. reducedim1 (R, A)
284
- i1 = first (Base. axes1 (R))
285
- @inbounds for IA in CartesianIndices (indsAt)
286
- IR = Broadcast. newindex (IA, keep, Idefault)
287
- r = R[i1,IR]
288
- m = means[i1,IR]
289
- @simd for i in axes (A, 1 )
290
- r += abs2 (A[i,IA] - m)
291
- end
292
- R[i1,IR] = r
293
- end
294
- else
295
- @inbounds for IA in CartesianIndices (indsAt)
296
- IR = Broadcast. newindex (IA, keep, Idefault)
297
- @simd for i in axes (A, 1 )
298
- R[i,IR] += abs2 (A[i,IA] - means[i,IR])
299
- end
300
- end
301
- end
302
- return R
262
+ struct CentralizedAbs2Fun{T,S} <: Function
263
+ mean:: S
303
264
end
265
+ CentralizedAbs2Fun {T} (means) where {T} = CentralizedAbs2Fun {T,typeof(means)} (means)
266
+ CentralizedAbs2Fun (means) = CentralizedAbs2Fun {typeof(means)} (means)
267
+ CentralizedAbs2Fun (means, extrude) = CentralizedAbs2Fun {eltype(means)} (Broadcast. extrude (means))
268
+ # Division is generally costly, but Julia is typically able to constant propagate a /1
269
+ # and simply ensure we get the type right at no cost, allowing the division in-place later
270
+ (f:: CentralizedAbs2Fun )(x) = abs2 .(x - f. mean)/ 1
271
+ (f:: CentralizedAbs2Fun{<:Any,<:Broadcast.Extruded} )((i, x),) = abs2 .(x - Broadcast. _broadcast_getindex (f. mean, i))/ 1
272
+ _doubled (x) = x+ x
273
+ Base. mapreduce_empty (:: CentralizedAbs2Fun{T,<:Broadcast.Extruded} , :: typeof (Base. add_sum), :: Type{Tuple{_Any,S}} ) where {T<: Number , S<: Number , _Any} = _doubled (abs2 (zero (T)- zero (S)))/ 1
274
+ Base. mapreduce_empty (:: CentralizedAbs2Fun{T,<:Broadcast.Extruded} , :: typeof (Base. add_sum), :: Type{Tuple{_Any, Union{Missing, S}}} ) where {T<: Number , S<: Number , _Any} = _doubled (abs2 (zero (T)- zero (S)))/ 1
275
+ Base. mapreduce_empty (:: CentralizedAbs2Fun{T} , :: typeof (Base. add_sum), :: Type{S} ) where {T<: Number , S<: Number } = _doubled (abs2 (zero (T)- zero (S)))/ 1
276
+ Base. mapreduce_empty (:: CentralizedAbs2Fun{T} , :: typeof (Base. add_sum), :: Type{Union{Missing, S}} ) where {T<: Number , S<: Number } = _doubled (abs2 (zero (T)- zero (S)))/ 1
277
+
278
+ centralize_sumabs2 (A:: AbstractArray , m) =
279
+ sum (CentralizedAbs2Fun (m), A)
280
+ centralize_sumabs2 (A:: AbstractArray , m:: AbstractArray , region) =
281
+ sum (CentralizedAbs2Fun (m, true ), Base. PairsArray (A), dims= region)
282
+ centralize_sumabs2! (R:: AbstractArray , A:: AbstractArray , means:: AbstractArray ) =
283
+ sum! (CentralizedAbs2Fun (means, true ), R, Base. PairsArray (A))
284
+
304
285
305
286
function varm! (R:: AbstractArray{S} , A:: AbstractArray , m:: AbstractArray ; corrected:: Bool = true ) where S
306
- if isempty (A)
287
+ _checkm (R, m, ntuple (identity, Val (max (ndims (R), ndims (m)))))
288
+ if isempty (A) || length (A) == 1 && corrected
307
289
fill! (R, convert (S, NaN ))
308
290
else
309
- rn = div ( length (A), length (R )) - Int (corrected)
291
+ rn = prod ( ntuple (d -> size (R, d) == 1 ? size (A, d) : 1 , Val ( max ( ndims (A), ndims (R))) )) - Int (corrected)
310
292
centralize_sumabs2! (R, A, m)
311
293
R .= R .* (1 // rn)
312
294
end
@@ -339,15 +321,33 @@ over dimensions. In that case, `mean` must be an array with the same shape as
339
321
"""
340
322
varm (A:: AbstractArray , m:: AbstractArray ; corrected:: Bool = true , dims= :) = _varm (A, m, corrected, dims)
341
323
342
- _varm (A:: AbstractArray{T} , m, corrected:: Bool , region) where {T} =
343
- varm! (Base. reducedim_init (t -> abs2 (t)/ 2 , + , A, region), A, m; corrected= corrected)
324
+ _throw_mean_mismatch (A, m, region) = throw (DimensionMismatch (" axes of means ($(axes (m)) ) does not match reduction over $(region) of $(axes (A)) " ))
325
+ function _checkm (A:: AbstractArray , m:: AbstractArray , region)
326
+ for d in 1 : max (ndims (A), ndims (m))
327
+ if d in region
328
+ size (m, d) == 1 || _throw_mean_mismatch (A, m, region)
329
+ else
330
+ axes (m, d) == axes (A, d) || _throw_mean_mismatch (A, m, region)
331
+ end
332
+ end
333
+ end
334
+ function _varm (A:: AbstractArray , m, corrected:: Bool , region)
335
+ _checkm (A, m, region)
336
+ rn = prod (ntuple (d-> d in region ? size (A, d) : 1 , Val (ndims (A)))) - Int (corrected)
337
+ R = centralize_sumabs2 (A, m, region)
338
+ if rn <= 0
339
+ R .= R ./ 0
340
+ else
341
+ R .= R .* 1 // rn # why use Rational?
342
+ end
343
+ return R
344
+ end
344
345
345
346
varm (A:: AbstractArray , m; corrected:: Bool = true ) = _varm (A, m, corrected, :)
346
347
347
348
function _varm (A:: AbstractArray{T} , m, corrected:: Bool , :: Colon ) where T
348
- n = length (A)
349
- n == 0 && return oftype ((abs2 (zero (T)) + abs2 (zero (T)))/ 2 , NaN )
350
- return centralize_sumabs2 (A, m) / (n - Int (corrected))
349
+ rn = max (length (A) - Int (corrected), 0 )
350
+ centralize_sumabs2 (A, m)/ rn
351
351
end
352
352
353
353
0 commit comments