diff --git a/Project.toml b/Project.toml index 13264cc3..f4ee73e4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FillArrays" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.7.0" +version = "1.7.1" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 7ca7b78e..23aa6279 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -550,16 +550,6 @@ for SMT in (:Diagonal, :Bidiagonal, :Tridiagonal, :SymTridiagonal) end end - -######### -# maximum/minimum -######### - -for op in (:maximum, :minimum) - @eval $op(x::AbstractFill) = getindex_value(x) -end - - ######### # Cumsum ######### diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 6cf11284..36c4ac3d 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -33,22 +33,81 @@ end ### mapreduce function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, ::Colon) - fval = f(getindex_value(A)) - out = fval - for _ in 2:length(A) - out = op(out, fval) + if length(A) == 0 + return Base.mapreduce_empty_iter(f, op, A, Base.HasEltype()) + end + val = getindex_value(A) + out = Base.mapreduce_first(f, op, val) + fval = f(val) + if op(out, fval) != out + for _ in 2:length(A) + out = op(out, fval) + end + end + out +end + +function Base._mapreduce_dim(f, op, init, A::AbstractFill, ::Colon) + if length(A) == 0 + return init + end + val = getindex_value(A) + fval = f(val) + out = op(init, fval) + if op(out, fval) != out + for _ in 2:length(A) + out = op(out, fval) + end end out end +identityel(f, ::Union{typeof(+), typeof(Base.add_sum)}, A) = zero(f(zero(eltype(A)))) +identityel(f, ::Union{typeof(*), typeof(Base.mul_prod)}, A) = one(f(one(eltype(A)))) +identityel(f, ::typeof(&), A) = true +identityel(f, ::typeof(|), A) = false +identityel(f, ::Any, @nospecialize(A)) = throw(ArgumentError("reducing over an empty collection is not allowed")) +function mapreducedim_empty(f, op, A) + z = identityel(f, op, A) + op(z, z) +end + +function reduced_indices(A, dims) + ntuple(d -> d in dims ? axes(A,ndims(A)+1) : axes(A,d), ndims(A)) +end + function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, dims) - fval = f(getindex_value(A)) red = *(ntuple(d -> d in dims ? size(A,d) : 1, ndims(A))...) - out = fval - for _ in 2:red - out = op(out, fval) + if red == 0 + out = mapreducedim_empty(f, op, A) + else + val = getindex_value(A) + out = Base.mapreduce_first(f, op, val) + fval = f(val) + if op(out, fval) != out + for _ in 2:red + out = op(out, fval) + end + end + end + Fill(out, reduced_indices(A, dims)) +end + +function Base._mapreduce_dim(f, op, init, A::AbstractFill, dims) + red = *(ntuple(d -> d in dims ? size(A,d) : 1, ndims(A))...) + if red == 0 + out = init + else + val = getindex_value(A) + fval = f(val) + out = op(init, fval) + if op(out, fval) != out + for _ in 2:red + out = op(out, fval) + end + end end - Fill(out, ntuple(d -> d in dims ? Base.OneTo(1) : axes(A,d), ndims(A))) + Fill(out, reduced_indices(A, dims)) end function mapreduce(f, op, A::AbstractFill, B::AbstractFill; kw...) diff --git a/test/runtests.jl b/test/runtests.jl index 50d42c54..2bacc8c6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1090,7 +1090,7 @@ end @test Zeros(S, 10) .* (T(1):T(10)) ≡ Zeros(U, 10) @test_throws DimensionMismatch Zeros(S, 10) .* (T(1):T(11)) end - end + end end end @@ -1115,6 +1115,62 @@ end end @testset "mapreduce" begin + @testset "corner cases with small arrays" begin + @test_throws Exception reduce(max, Fill(2,0)) + @test_throws Exception reduce(max, Fill(2,0), dims=1) + function testreduce(op, A; kw...) + B = Array(A) + F = reduce(op, A; kw...) + @test F == reduce(op, B; kw...) + if haskey(kw, :dims) + @test F isa Fill + end + if !isempty(A) + @test mapreduce(x->x^2, op, A; kw...) == mapreduce(x->x^2, op, B; kw...) + end + end + @testset for (op, init) in ((+, 0), (*, 1)) + testreduce(op, Fill(2,0)) + testreduce(op, Fill(2,0); init) + testreduce(op, Fill(2,0), dims=1) + testreduce(op, Fill(2,0); init, dims=1) + end + @testset for (op, init) in ((&, true), (|, false)) + testreduce(op, Fill(true,0)) + testreduce(op, Fill(true,0); init) + testreduce(op, Fill(true,0), dims=1) + testreduce(op, Fill(true,0); init, dims=1) + end + testreduce(vcat, Fill(2,0), init=Int[]) + testreduce(vcat, Fill(2,1), init=Int[]) + @testset for (op, init) in ((max, 0), (+, 0), (*, 1)) + testreduce(op, Fill(2,0), dims=2) + testreduce(op, Fill(2,0,1), dims=2) + testreduce(op, Fill(2,0); init, dims=2) + testreduce(op, Fill(2,0,1); init, dims=2) + + testreduce(op, Fill(2,1)) + testreduce(op, Fill(2,1), dims=1) + testreduce(op, Fill(2,1), dims=2) + testreduce(op, Fill(2,1,1), dims=1) + testreduce(op, Fill(2,1,1), dims=2) + testreduce(op, Fill(2,1); init, dims=1) + testreduce(op, Fill(2,1); init, dims=2) + testreduce(op, Fill(2,1,1); init, dims=1) + testreduce(op, Fill(2,1,1); init, dims=2) + + testreduce(op, Fill(2,2)) + testreduce(op, Fill(2,2), dims=1) + testreduce(op, Fill(2,2), dims=2) + testreduce(op, Fill(2,2,2), dims=1) + testreduce(op, Fill(2,2,2), dims=2) + testreduce(op, Fill(2,2); init, dims=1) + testreduce(op, Fill(2,2); init, dims=2) + testreduce(op, Fill(2,2,2); init, dims=1) + testreduce(op, Fill(2,2,2); init, dims=2) + end + end + x = rand(3, 4) y = fill(1.0, 3, 4) Y = Fill(1.0, 3, 4)