Skip to content

Commit 301db97

Browse files
authored
implement count and count! using mapreduce (#34048)
This creates the same calling interface for `count` as for other mapreduce-type functions like e.g. `sum`, namely allowing the `dims` keyword. The implementation itself is shorter than before without sacrificing performance. More detailed documentation for `count` was added too.
1 parent dddda07 commit 301db97

File tree

5 files changed

+89
-14
lines changed

5 files changed

+89
-14
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ New library features
146146
will acquire locks for safe multi-threaded access. Setting it to `false` provides better
147147
performance when only one thread will access the file.
148148
* The introspection macros (`@which`, `@code_typed`, etc.) now work with `do`-block syntax ([#35283]) and with dot syntax ([#35522]).
149+
* `count` now accepts the `dims` keyword.
150+
* new in-place `count!` function similar to `sum!`.
149151

150152
Standard library changes
151153
------------------------

base/exports.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,7 @@ export
486486
any,
487487
firstindex,
488488
collect,
489+
count!,
489490
count,
490491
delete!,
491492
deleteat!,

base/reduce.jl

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -836,6 +836,8 @@ end
836836

837837
## count
838838

839+
_bool(f::Function) = x->f(x)::Bool
840+
839841
"""
840842
count(p, itr) -> Integer
841843
count(itr) -> Integer
@@ -853,22 +855,10 @@ julia> count([true, false, true, true])
853855
3
854856
```
855857
"""
856-
function count(pred, itr)
857-
n = 0
858-
for x in itr
859-
n += pred(x)::Bool
860-
end
861-
return n
862-
end
863-
function count(pred, a::AbstractArrayOrBroadcasted)
864-
n = 0
865-
for i in eachindex(a)
866-
@inbounds n += pred(a[i])::Bool
867-
end
868-
return n
869-
end
870858
count(itr) = count(identity, itr)
871859

860+
count(f, itr) = mapreduce(_bool(f), add_sum, itr, init=0)
861+
872862
function count(::typeof(identity), x::Array{Bool})
873863
n = 0
874864
chunks = length(x) ÷ sizeof(UInt)

base/reducedim.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,67 @@ julia> reduce(max, a, dims=1)
359359
reduce(op, A::AbstractArray; kw...) = mapreduce(identity, op, A; kw...)
360360

361361
##### Specific reduction functions #####
362+
363+
"""
364+
count([f=identity,] A::AbstractArray; dims=:)
365+
366+
Count the number of elements in `A` for which `f` returns `true` over the given
367+
dimensions.
368+
369+
!!! compat "Julia 1.5"
370+
`dims` keyword was added in Julia 1.5.
371+
372+
# Examples
373+
```jldoctest
374+
julia> A = [1 2; 3 4]
375+
2×2 Array{Int64,2}:
376+
1 2
377+
3 4
378+
379+
julia> count(<=(2), A, dims=1)
380+
1×2 Array{Int64,2}:
381+
1 1
382+
383+
julia> count(<=(2), A, dims=2)
384+
2×1 Array{Int64,2}:
385+
2
386+
0
387+
```
388+
"""
389+
count(A::AbstractArrayOrBroadcasted; dims=:) = count(identity, A, dims=dims)
390+
count(f, A::AbstractArrayOrBroadcasted; dims=:) = mapreduce(_bool(f), add_sum, A, dims=dims, init=0)
391+
392+
"""
393+
count!([f=identity,] r, A; init=true)
394+
395+
Count the number of elements in `A` for which `f` returns `true` over the
396+
singleton dimensions of `r`, writing the result into `r` in-place.
397+
If `init` is `true`, values in `r` are initialized to zero.
398+
399+
!!! compat "Julia 1.5"
400+
inplace `count!` was added in Julia 1.5.
401+
402+
# Examples
403+
```jldoctest
404+
julia> A = [1 2; 3 4]
405+
2×2 Array{Int64,2}:
406+
1 2
407+
3 4
408+
409+
julia> count!(<=(2), [1 1], A)
410+
1×2 Array{Int64,2}:
411+
1 1
412+
413+
julia> count!(<=(2), [1; 1], A)
414+
2-element Array{Int64,1}:
415+
2
416+
0
417+
```
418+
"""
419+
count!(r::AbstractArray, A::AbstractArrayOrBroadcasted; init::Bool=true) = count!(identity, r, A; init=init)
420+
count!(f, r::AbstractArray, A::AbstractArrayOrBroadcasted; init::Bool=true) =
421+
mapreducedim!(_bool(f), add_sum, initarray!(r, add_sum, init, A), A)
422+
362423
"""
363424
sum(A::AbstractArray; dims)
364425

test/reducedim.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ safe_sum(A::Array{T}, region) where {T} = safe_mapslices(sum, A, region)
1212
safe_prod(A::Array{T}, region) where {T} = safe_mapslices(prod, A, region)
1313
safe_maximum(A::Array{T}, region) where {T} = safe_mapslices(maximum, A, region)
1414
safe_minimum(A::Array{T}, region) where {T} = safe_mapslices(minimum, A, region)
15+
safe_count(A::AbstractArray{T}, region) where {T} = safe_mapslices(count, A, region)
1516
safe_sumabs(A::Array{T}, region) where {T} = safe_mapslices(sum, abs.(A), region)
1617
safe_sumabs2(A::Array{T}, region) where {T} = safe_mapslices(sum, abs2.(A), region)
1718
safe_maxabs(A::Array{T}, region) where {T} = safe_mapslices(maximum, abs.(A), region)
@@ -21,15 +22,21 @@ safe_minabs(A::Array{T}, region) where {T} = safe_mapslices(minimum, abs.(A), re
2122
1, 2, 3, 4, 5, (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4),
2223
(1, 2, 3), (1, 3, 4), (2, 3, 4), (1, 2, 3, 4)]
2324
Areduc = rand(3, 4, 5, 6)
25+
Breduc = rand(Bool, 3, 4, 5, 6)
26+
@assert axes(Areduc) == axes(Breduc)
27+
2428
r = fill(NaN, map(length, Base.reduced_indices(axes(Areduc), region)))
2529
@test sum!(r, Areduc) safe_sum(Areduc, region)
2630
@test prod!(r, Areduc) safe_prod(Areduc, region)
2731
@test maximum!(r, Areduc) safe_maximum(Areduc, region)
2832
@test minimum!(r, Areduc) safe_minimum(Areduc, region)
33+
@test count!(r, Breduc) safe_count(Breduc, region)
34+
2935
@test sum!(abs, r, Areduc) safe_sumabs(Areduc, region)
3036
@test sum!(abs2, r, Areduc) safe_sumabs2(Areduc, region)
3137
@test maximum!(abs, r, Areduc) safe_maxabs(Areduc, region)
3238
@test minimum!(abs, r, Areduc) safe_minabs(Areduc, region)
39+
@test count!(!, r, Breduc) safe_count(.!Breduc, region)
3340

3441
# With init=false
3542
r2 = similar(r)
@@ -41,6 +48,9 @@ safe_minabs(A::Array{T}, region) where {T} = safe_mapslices(minimum, abs.(A), re
4148
@test maximum!(r, Areduc, init=false) fill!(r2, 1.8)
4249
fill!(r, -0.2)
4350
@test minimum!(r, Areduc, init=false) fill!(r2, -0.2)
51+
fill!(r, 1)
52+
@test count!(r, Breduc, init=false) safe_count(Breduc, region) .+ 1
53+
4454
fill!(r, 8.1)
4555
@test sum!(abs, r, Areduc, init=false) safe_sumabs(Areduc, region) .+ 8.1
4656
fill!(r, 8.1)
@@ -49,15 +59,20 @@ safe_minabs(A::Array{T}, region) where {T} = safe_mapslices(minimum, abs.(A), re
4959
@test maximum!(abs, r, Areduc, init=false) fill!(r2, 1.5)
5060
fill!(r, -1.5)
5161
@test minimum!(abs, r, Areduc, init=false) fill!(r2, -1.5)
62+
fill!(r, 1)
63+
@test count!(!, r, Breduc, init=false) safe_count(.!Breduc, region) .+ 1
5264

5365
@test @inferred(sum(Areduc, dims=region)) safe_sum(Areduc, region)
5466
@test @inferred(prod(Areduc, dims=region)) safe_prod(Areduc, region)
5567
@test @inferred(maximum(Areduc, dims=region)) safe_maximum(Areduc, region)
5668
@test @inferred(minimum(Areduc, dims=region)) safe_minimum(Areduc, region)
69+
@test @inferred(count(Breduc, dims=region)) safe_count(Breduc, region)
70+
5771
@test @inferred(sum(abs, Areduc, dims=region)) safe_sumabs(Areduc, region)
5872
@test @inferred(sum(abs2, Areduc, dims=region)) safe_sumabs2(Areduc, region)
5973
@test @inferred(maximum(abs, Areduc, dims=region)) safe_maxabs(Areduc, region)
6074
@test @inferred(minimum(abs, Areduc, dims=region)) safe_minabs(Areduc, region)
75+
@test @inferred(count(!, Breduc, dims=region)) safe_count(.!Breduc, region)
6176
end
6277

6378
# Test reduction along first dimension; this is special-cased for
@@ -416,3 +431,9 @@ end
416431

417432
@test sum([Variable(:x), Variable(:y)], dims=1) == [AffExpr([Variable(:x), Variable(:y)])]
418433
end
434+
435+
# count
436+
@testset "count: throw on non-bool types" begin
437+
@test_throws TypeError count([1], dims=1)
438+
@test_throws TypeError count!([1], [1])
439+
end

0 commit comments

Comments
 (0)