diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 66dad480..615f1a49 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -7,7 +7,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, any, all, axes, isone, iszero, iterate, unique, allunique, permutedims, inv, copy, vec, setindex!, count, ==, reshape, map, zero, show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat, - parent, similar, issorted, add_sum, accumulate, OneTo, permutedims + parent, similar, issorted, add_sum, accumulate, OneTo, permutedims, mul_prod import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!, dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec, @@ -591,22 +591,11 @@ for SMT in (:Diagonal, :Bidiagonal, :Tridiagonal, :SymTridiagonal) end -######### -# maximum/minimum -######### - -for op in (:maximum, :minimum) - @eval $op(x::AbstractFill) = getindex_value(x) -end - - ######### # Cumsum ######### # These methods are necessary to deal with infinite arrays -sum(x::AbstractFill) = getindex_value(x)*length(x) -sum(f, x::AbstractFill) = length(x) * f(getindex_value(x)) sum(x::AbstractZeros) = getindex_value(x) # needed to support infinite case diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 2b5ea59c..f77dc1a1 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -31,24 +31,28 @@ function _maplinear(rs...) # tries to match Base's behaviour, could perhaps hook end ### mapreduce - -function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, ::Colon) - fval = f(getindex_value(A)) - out = fval - for _ in 2:length(A) - out = op(out, fval) +@inline red(A, dims) = *(ntuple(d -> d in dims ? size(A,d) : 1, ndims(A))...) +@inline outdim(A, dims) = ntuple(d -> d in dims ? Base.OneTo(1) : axes(A,d), ndims(A)) +@inline function iterfun(f, n, val, nt=val) + out = nt + for _ in 1:n + out = f(out, val) end out end +Base._mapreduce_dim(f, op, nt, A::AbstractFill, ::Colon) = iterfun(op, length(A), f(getindex_value(A)), nt) +Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, ::Colon) = iterfun(op, length(A)-1, f(getindex_value(A))) +Base._mapreduce_dim(f, op, nt, A::AbstractFill, dims) = Fill(iterfun(op, red(A, dims), f(getindex_value(A)), nt), outdim(A, dims)) +Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, dims) = Fill(iterfun(op, red(A, dims)-1, f(getindex_value(A))), outdim(A, dims)) -function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, dims) - fval = f(getindex_value(A)) - red = *(ntuple(d -> d in dims ? size(A,d) : 1, ndims(A))...) - out = fval - for _ in 2:red - out = op(out, fval) +firstval(a, b) = a +for (op, iterop) in ((:+, :*), (:*, :^), (:add_sum, :mul_prod), (:mul_prod, :^), (:max, :firstval), (:min, :firstval), (:|, :firstval), (:&, :firstval)) + @eval begin + Base._mapreduce_dim(f, ::typeof($op), nt, A::AbstractFill, dims) = Fill($op(nt, $iterop(f(getindex_value(A)), red(A, dims))), outdim(A, dims)) + Base._mapreduce_dim(f, ::typeof($op), ::Base._InitialValue, A::AbstractFill, dims) = Fill(($iterop)(f(getindex_value(A)), red(A, dims)), outdim(A, dims)) + Base._mapreduce_dim(f, ::typeof($op), nt, A::AbstractFill, ::Colon) = $op(nt, ($iterop)(f(getindex_value(A)), length(A))) + Base._mapreduce_dim(f, ::typeof($op), ::Base._InitialValue, A::AbstractFill, ::Colon) = ($iterop)(f(getindex_value(A)), length(A)) end - Fill(out, ntuple(d -> d in dims ? Base.OneTo(1) : axes(A,d), ndims(A))) end function mapreduce(f, op, A::AbstractFill, B::AbstractFill; kw...) diff --git a/test/runtests.jl b/test/runtests.jl index 088089d4..4a9702b3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1210,12 +1210,18 @@ end Y = Fill(1.0, 3, 4) O = Ones(3, 4) + op2(x,y) = x^2 + 3y + @test mapreduce(exp, +, Y) == mapreduce(exp, +, y) @test mapreduce(exp, +, Y; dims=2) == mapreduce(exp, +, y; dims=2) @test mapreduce(identity, +, Y) == sum(y) == sum(Y) @test mapreduce(identity, +, Y, dims=1) == sum(y, dims=1) == sum(Y, dims=1) - @test mapreduce(exp, +, Y; dims=(1,), init=5.0) == mapreduce(exp, +, y; dims=(1,), init=5.0) + @test isapprox(mapreduce(exp, +, Y; dims=(1,), init=5.0), mapreduce(exp, +, y; dims=(1,), init=5.0), rtol=eps()) + @test mapreduce(exp, op2, Y; dims=(1,), init=5.0) == mapreduce(exp, op2, y; dims=(1,), init=5.0) + @test mapreduce(exp, op2, Y; dims=(1,)) == [mapreduce(exp, op2, y[:,k]) for k in 1:4]' # see https://github.com/JuliaLang/julia/issues/52188 + @test mapreduce(exp, op2, Y; init=5.0) == mapreduce(exp, op2, y; init=5.0) + @test mapreduce(exp, op2, Y) == mapreduce(exp, op2, y) # Two arrays @test mapreduce(*, +, x, Y) == mapreduce(*, +, x, y) @@ -1224,7 +1230,6 @@ end @test mapreduce(*, +, Y, O) == mapreduce(*, +, y, y) f2(x,y) = 1 + x/y - op2(x,y) = x^2 + 3y @test mapreduce(f2, op2, x, Y) == mapreduce(f2, op2, x, y) @test mapreduce(f2, op2, x, Y, dims=1, init=5.0) == mapreduce(f2, op2, x, y, dims=1, init=5.0)