diff --git a/Project.toml b/Project.toml index a175d73..52f79a2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,22 +1,29 @@ name = "ArrayLayouts" uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" authors = ["Sheehan Olver "] -version = "1.2.1" +version = "1.3.0" [deps] FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +[weakdeps] +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" + +[extensions] +ArrayLayoutsFillArraysExt = "FillArrays" + [compat] FillArrays = "1.2.1" julia = "1.6" [extras] Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Base64", "Random", "StableRNGs", "Test"] +test = ["Base64", "FillArrays", "Random", "StableRNGs", "Test"] diff --git a/ext/ArrayLayoutsFillArraysExt.jl b/ext/ArrayLayoutsFillArraysExt.jl new file mode 100644 index 0000000..b8bcc61 --- /dev/null +++ b/ext/ArrayLayoutsFillArraysExt.jl @@ -0,0 +1,117 @@ +module ArrayLayoutsFillArraysExt + +using FillArrays +using FillArrays: AbstractFill, getindex_value + +using ArrayLayouts +using ArrayLayouts: OnesLayout, Mul, MulAdd, diagonal +import ArrayLayouts: MemoryLayout, _copyto!, sub_materialize, diagonaldata, mulzeros +export layoutfillmul + +import Base: copy, *, +, - +import Base.Broadcast: materialize! +import LinearAlgebra +using LinearAlgebra: Adjoint, Transpose, Symmetric, Hermitian, Diagonal, + AdjointAbsVec, TransposeAbsVec, UniformScaling + +macro layoutfillmul(Typ) + ret = quote + (*)(A::LinearAlgebra.AdjointAbsVec{<:Any,<:Zeros{<:Any,1}}, B::$Typ) = ArrayLayouts.mul(A,B) + (*)(A::LinearAlgebra.TransposeAbsVec{<:Any,<:Zeros{<:Any,1}}, B::$Typ) = ArrayLayouts.mul(A,B) + (*)(A::LinearAlgebra.Transpose{T,<:$Typ}, B::Zeros{T,1}) where T<:Real = ArrayLayouts.mul(A,B) + end + for Mod in (:Adjoint, :Transpose, :Symmetric, :Hermitian) + ret = quote + $ret + + (*)(A::$Mod{<:Any,<:$Typ}, B::Zeros{<:Any,1}) = ArrayLayouts.mul(A,B) + end + end + esc(ret) +end + +@layoutfillmul LayoutMatrix + +*(a::Zeros{<:Any,2}, b::LayoutMatrix) = FillArrays.mult_zeros(a, b) +*(a::LayoutMatrix, b::Zeros{<:Any,2}) = FillArrays.mult_zeros(a, b) +*(a::LayoutMatrix, b::Zeros{<:Any,1}) = FillArrays.mult_zeros(a, b) +*(a::Transpose{T, <:LayoutMatrix{T}} where T, b::Zeros{<:Any, 2}) = FillArrays.mult_zeros(a, b) +*(a::Adjoint{T, <:LayoutMatrix{T}} where T, b::Zeros{<:Any, 2}) = FillArrays.mult_zeros(a, b) +*(A::Adjoint{<:Any, <:Zeros{<:Any,1}}, B::Diagonal{<:Any,<:LayoutVector}) = (B' * A')' +*(A::Transpose{<:Any, <:Zeros{<:Any,1}}, B::Diagonal{<:Any,<:LayoutVector}) = transpose(transpose(B) * transpose(A)) +*(a::Adjoint{<:Number,<:LayoutVector}, b::Zeros{<:Number,1})= FillArrays._adjvec_mul_zeros(a, b) +function *(a::Transpose{T, <:LayoutVector{T}}, b::Zeros{T, 1}) where T<:Real + la, lb = length(a), length(b) + if la ≠ lb + throw(DimensionMismatch("dot product arguments have lengths $la and $lb")) + end + return zero(T) +end + +# equivalent to rescaling +function materialize!(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout}}) + M.B .= getindex_value(M.A.diag) .* M.B + M.B +end +# equivalent to rescaling +function materialize!(M::Rmul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) + M.A .= M.A .* getindex_value(M.B.diag) + M.A +end + +copy(M::Ldiv{<:DiagonalLayout{<:AbstractFillLayout}}) = inv(getindex_value(M.A.diag)) .* M.B +copy(M::Ldiv{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout}) = diagonal(inv(getindex_value(M.A.diag)) .* M.B.diag) + +copy(M::Rdiv{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A .* inv(getindex_value(M.B.diag)) +copy(M::Rdiv{<:DiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = diagonal(M.A.diag .* inv(getindex_value(M.B.diag))) + +copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout}}) = getindex_value(diagonaldata(M.A)) * M.B +copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B +copy(M::Rmul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B)) +copy(M::Rmul{<:DualLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B)) + +copy(M::Rmul{<:BidiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B)) +copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:BidiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B +copy(M::Rmul{<:TridiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B)) +copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:TridiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B +copy(M::Rmul{<:SymTridiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B)) +copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:SymTridiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B + +MemoryLayout(::Type{<:AbstractFill}) = FillLayout() +MemoryLayout(::Type{<:Zeros}) = ZerosLayout() +MemoryLayout(::Type{<:Ones}) = OnesLayout() + +_copyto!(_, ::AbstractFillLayout, dest::AbstractArray{<:Any,N}, src::AbstractArray{<:Any,N}) where N = + fill!(dest, getindex_value(src)) + + _fill_copyto!(dest, C::Zeros) = zero!(dest) # exploit special fill! overload + +sub_materialize(::AbstractFillLayout, V, ax) = Fill(getindex_value(V), ax) +sub_materialize(::ZerosLayout, V, ax) = Zeros{eltype(V)}(ax) +sub_materialize(::OnesLayout, V, ax) = Ones{eltype(V)}(ax) + +*(x::AdjointAbsVec{<:Any,<:Zeros{<:Any,1}}, D::Diagonal, y::LayoutVector) = FillArrays._triple_zeromul(x, D, y) +*(x::TransposeAbsVec{<:Any,<:Zeros{<:Any,1}}, D::Diagonal, y::LayoutVector) = FillArrays._triple_zeromul(x, D, y) + +@inline LinearAlgebra.dot(a::LayoutVector, b::AbstractFill{<:Any,1}) = FillArrays._fill_dot_rev(a,b) +@inline LinearAlgebra.dot(a::AbstractFill{<:Any,1}, b::LayoutVector) = FillArrays._fill_dot(a,b) + +# equivalent to rescaling +function materialize!(M::MulAdd{<:DiagonalLayout{<:AbstractFillLayout}}) + checkdimensions(M) + M.C .= (M.α * getindex_value(M.A.diag)) .* M.B .+ M.β .* M.C + M.C +end + +function materialize!(M::MulAdd{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) + checkdimensions(M) + M.C .= M.α .* M.A .* getindex_value(M.B.diag) .+ M.β .* M.C + M.C +end + +fillzeros(::Type{T}, ax) where T<:Number = Zeros{T}(ax) +mulzeros(::Type{T}, M) where T<:Number = fillzeros(T, axes(M)) +mulzeros(::Type{T}, M::Mul{<:DualLayout,<:Any,<:Adjoint}) where T<:Number = fillzeros(T, axes(M,2))' +mulzeros(::Type{T}, M::Mul{<:DualLayout,<:Any,<:Transpose}) where T<:Number = transpose(fillzeros(T, axes(M,2))) + +end diff --git a/src/ArrayLayouts.jl b/src/ArrayLayouts.jl index b21dc2f..e5fa878 100644 --- a/src/ArrayLayouts.jl +++ b/src/ArrayLayouts.jl @@ -1,6 +1,6 @@ module ArrayLayouts using Base: _typed_hcat -using Base, Base.Broadcast, LinearAlgebra, FillArrays, SparseArrays +using Base, Base.Broadcast, LinearAlgebra, SparseArrays using LinearAlgebra.BLAS using Base: AbstractCartesianIndex, OneTo, oneto, RangeIndex, ReinterpretArray, ReshapedArray, @@ -25,8 +25,6 @@ using LinearAlgebra.BLAS: BlasFloat, BlasReal, BlasComplex AdjointQtype{T} = isdefined(LinearAlgebra, :AdjointQ) ? LinearAlgebra.AdjointQ{T} : Adjoint{T,<:AbstractQ} -using FillArrays: AbstractFill, getindex_value, axes_print_matrix_row - using Base: require_one_based_indexing export materialize, materialize!, MulAdd, muladd!, Ldiv, Rdiv, Lmul, Rmul, Dot, @@ -121,6 +119,11 @@ include("diagonal.jl") include("triangular.jl") include("factorizations.jl") +@static if !isdefined(Base, :get_extension) + include("../ext/ArrayLayoutsFillArraysExt.jl") +end + + # Extend this function if you're only looking to dispatch on the axes @inline sub_materialize_axes(V, _) = Array(V) @inline sub_materialize(_, V, ax) = sub_materialize_axes(V, ax) @@ -196,22 +199,6 @@ getindex(A::LayoutVector, kr::Colon) = layout_getindex(A, kr) getindex(A::AdjOrTrans{<:Any,<:LayoutVector}, kr::Integer, jr::Colon) = layout_getindex(A, kr, jr) getindex(A::AdjOrTrans{<:Any,<:LayoutVector}, kr::Integer, jr::AbstractVector) = layout_getindex(A, kr, jr) -*(a::Zeros{<:Any,2}, b::LayoutMatrix) = FillArrays.mult_zeros(a, b) -*(a::LayoutMatrix, b::Zeros{<:Any,2}) = FillArrays.mult_zeros(a, b) -*(a::LayoutMatrix, b::Zeros{<:Any,1}) = FillArrays.mult_zeros(a, b) -*(a::Transpose{T, <:LayoutMatrix{T}} where T, b::Zeros{<:Any, 2}) = FillArrays.mult_zeros(a, b) -*(a::Adjoint{T, <:LayoutMatrix{T}} where T, b::Zeros{<:Any, 2}) = FillArrays.mult_zeros(a, b) -*(A::Adjoint{<:Any, <:Zeros{<:Any,1}}, B::Diagonal{<:Any,<:LayoutVector}) = (B' * A')' -*(A::Transpose{<:Any, <:Zeros{<:Any,1}}, B::Diagonal{<:Any,<:LayoutVector}) = transpose(transpose(B) * transpose(A)) -*(a::Adjoint{<:Number,<:LayoutVector}, b::Zeros{<:Number,1})= FillArrays._adjvec_mul_zeros(a, b) -function *(a::Transpose{T, <:LayoutVector{T}}, b::Zeros{T, 1}) where T<:Real - la, lb = length(a), length(b) - if la ≠ lb - throw(DimensionMismatch("dot product arguments have lengths $la and $lb")) - end - return zero(T) -end - *(A::Diagonal{<:Any,<:LayoutVector}, B::Diagonal{<:Any,<:LayoutVector}) = mul(A, B) *(A::Diagonal{<:Any,<:LayoutVector}, B::AbstractMatrix) = mul(A, B) *(A::AbstractMatrix, B::Diagonal{<:Any,<:LayoutVector}) = mul(A, B) @@ -385,6 +372,16 @@ Base.replace_in_print_matrix(A::Union{LayoutVector, UnitLowerTriangular{<:Any,<:AdjOrTrans{<:Any,<:LayoutMatrix}}}, i::Integer, j::Integer, s::AbstractString) = layout_replace_in_print_matrix(MemoryLayout(A), A, i, j, s) +if VERSION < v"1.8-" + axes_print_matrix_row(lay, io, X, A, i, cols, sep) = + Base.invoke(Base.print_matrix_row, Tuple{IO,AbstractVecOrMat,Vector,Integer,AbstractVector,AbstractString}, + io, X, A, i, cols, sep) +else + axes_print_matrix_row(lay, io, X, A, i, cols, sep, idxlast::Integer=last(axes(X, 2))) = + Base.invoke(Base.print_matrix_row, Tuple{IO,AbstractVecOrMat,Vector,Integer,AbstractVector,AbstractString,Integer}, + io, X, A, i, cols, sep, idxlast) +end + Base.print_matrix_row(io::IO, X::Union{LayoutMatrix, LayoutVector, diff --git a/src/diagonal.jl b/src/diagonal.jl index 1a3e26d..22b86b0 100644 --- a/src/diagonal.jl +++ b/src/diagonal.jl @@ -8,11 +8,6 @@ mulreduce(M::Mul{<:Any,<:DiagonalLayout}) = Rmul(M) # Diagonal multiplication never changes structure similar(M::Lmul{<:DiagonalLayout}, ::Type{T}, axes) where T = similar(M.B, T, axes) -# equivalent to rescaling -function materialize!(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout}}) - M.B .= getindex_value(M.A.diag) .* M.B - M.B -end copy(M::Lmul{<:DiagonalLayout,<:DiagonalLayout}) = diagonal(diagonaldata(M.A) .* diagonaldata(M.B)) @@ -32,11 +27,6 @@ end # Diagonal multiplication never changes structure similar(M::Rmul{<:Any,<:DiagonalLayout}, ::Type{T}, axes) where T = similar(M.A, T, axes) -# equivalent to rescaling -function materialize!(M::Rmul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) - M.A .= M.A .* getindex_value(M.B.diag) - M.A -end function materialize!(M::Ldiv{<:DiagonalLayout}) @@ -46,13 +36,9 @@ end copy(M::Ldiv{<:DiagonalLayout,<:DiagonalLayout}) = diagonal(M.A.diag .\ M.B.diag) copy(M::Ldiv{<:DiagonalLayout}) = M.A.diag .\ M.B -copy(M::Ldiv{<:DiagonalLayout{<:AbstractFillLayout}}) = inv(getindex_value(M.A.diag)) .* M.B -copy(M::Ldiv{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout}) = diagonal(inv(getindex_value(M.A.diag)) .* M.B.diag) copy(M::Rdiv{<:DiagonalLayout,<:DiagonalLayout}) = diagonal(M.A.diag .* inv.(M.B.diag)) copy(M::Rdiv{<:Any,<:DiagonalLayout}) = M.A .* inv.(permutedims(M.B.diag)) -copy(M::Rdiv{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A .* inv(getindex_value(M.B.diag)) -copy(M::Rdiv{<:DiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = diagonal(M.A.diag .* inv(getindex_value(M.B.diag))) @@ -71,18 +57,6 @@ copy(M::Lmul{DiagonalLayout{OnesLayout},DiagonalLayout{OnesLayout}}) = _copy_oft copy(M::Rmul{<:Any,DiagonalLayout{OnesLayout}}) = _copy_oftype(M.A, eltype(M)) copy(M::Rmul{<:DualLayout,DiagonalLayout{OnesLayout}}) = _copy_oftype(M.A, eltype(M)) -copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout}}) = getindex_value(diagonaldata(M.A)) * M.B -copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B -copy(M::Rmul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B)) -copy(M::Rmul{<:DualLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B)) - -copy(M::Rmul{<:BidiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B)) -copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:BidiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B -copy(M::Rmul{<:TridiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B)) -copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:TridiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B -copy(M::Rmul{<:SymTridiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B)) -copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:SymTridiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B - copy(M::Rmul{<:BidiagonalLayout,DiagonalLayout{OnesLayout}}) = _copy_oftype(M.A, eltype(M)) copy(M::Lmul{DiagonalLayout{OnesLayout},<:BidiagonalLayout}) = _copy_oftype(M.B, eltype(M)) diff --git a/src/memorylayout.jl b/src/memorylayout.jl index f7b9e46..b5a4040 100644 --- a/src/memorylayout.jl +++ b/src/memorylayout.jl @@ -531,23 +531,12 @@ struct ZerosLayout <: AbstractFillLayout end struct OnesLayout <: AbstractFillLayout end struct EyeLayout <: MemoryLayout end -MemoryLayout(::Type{<:AbstractFill}) = FillLayout() -MemoryLayout(::Type{<:Zeros}) = ZerosLayout() -MemoryLayout(::Type{<:Ones}) = OnesLayout() - # all sub arrays are same sublayout(L::AbstractFillLayout, inds::Type) = L reshapedlayout(L::AbstractFillLayout, _) = L adjointlayout(::Type, L::AbstractFillLayout) = L transposelayout(L::AbstractFillLayout) = L -_copyto!(_, ::AbstractFillLayout, dest::AbstractArray{<:Any,N}, src::AbstractArray{<:Any,N}) where N = - fill!(dest, getindex_value(src)) - -sub_materialize(::AbstractFillLayout, V, ax) = Fill(getindex_value(V), ax) -sub_materialize(::ZerosLayout, V, ax) = Zeros{eltype(V)}(ax) -sub_materialize(::OnesLayout, V, ax) = Ones{eltype(V)}(ax) - abstract type AbstractBandedLayout <: MemoryLayout end abstract type AbstractTridiagonalLayout <: AbstractBandedLayout end diff --git a/src/mul.jl b/src/mul.jl index e23d703..d411e54 100644 --- a/src/mul.jl +++ b/src/mul.jl @@ -213,9 +213,6 @@ macro layoutmul(Typ) (*)(A::AbstractMatrix, B::$Typ) = ArrayLayouts.mul(A,B) (*)(A::LinearAlgebra.AdjointAbsVec, B::$Typ) = ArrayLayouts.mul(A,B) (*)(A::LinearAlgebra.TransposeAbsVec, B::$Typ) = ArrayLayouts.mul(A,B) - (*)(A::LinearAlgebra.AdjointAbsVec{<:Any,<:Zeros{<:Any,1}}, B::$Typ) = ArrayLayouts.mul(A,B) - (*)(A::LinearAlgebra.TransposeAbsVec{<:Any,<:Zeros{<:Any,1}}, B::$Typ) = ArrayLayouts.mul(A,B) - (*)(A::LinearAlgebra.Transpose{T,<:$Typ}, B::Zeros{T,1}) where T<:Real = ArrayLayouts.mul(A,B) (*)(A::LinearAlgebra.AbstractQ, B::$Typ) = ArrayLayouts.mul(A,B) (*)(A::$Typ, B::LinearAlgebra.AbstractQ) = ArrayLayouts.mul(A,B) @@ -278,7 +275,6 @@ macro layoutmul(Typ) (*)(A::LinearAlgebra.TransposeAbsVec, B::$Mod{<:Any,<:$Typ}) = ArrayLayouts.mul(A,B) (*)(A::$Mod{<:Any,<:$Typ}, B::AbstractVector) = ArrayLayouts.mul(A,B) (*)(A::$Mod{<:Any,<:$Typ}, B::ArrayLayouts.LayoutVector) = ArrayLayouts.mul(A,B) - (*)(A::$Mod{<:Any,<:$Typ}, B::Zeros{<:Any,1}) = ArrayLayouts.mul(A,B) (*)(A::$Mod{<:Any,<:$Typ}, B::$Typ) = ArrayLayouts.mul(A,B) (*)(A::$Typ, B::$Mod{<:Any,<:$Typ}) = ArrayLayouts.mul(A,B) @@ -306,8 +302,6 @@ end *(x::Transpose{<:Any,<:LayoutVector}, D::Diagonal{<:Any,<:LayoutVector}) = mul(x, D) *(x::AdjointAbsVec, D::Diagonal, y::LayoutVector) = x * mul(D,y) *(x::TransposeAbsVec, D::Diagonal, y::LayoutVector) = x * mul(D,y) -*(x::AdjointAbsVec{<:Any,<:Zeros{<:Any,1}}, D::Diagonal, y::LayoutVector) = FillArrays._triple_zeromul(x, D, y) -*(x::TransposeAbsVec{<:Any,<:Zeros{<:Any,1}}, D::Diagonal, y::LayoutVector) = FillArrays._triple_zeromul(x, D, y) *(A::UpperOrLowerTriangular{<:Any,<:LayoutMatrix}, B::UpperOrLowerTriangular{<:Any,<:LayoutMatrix}) = mul(A, B) @@ -358,8 +352,6 @@ dot(a, b) = materialize(Dot(a, b)) @inline LinearAlgebra.dot(a::LayoutArray, b::LayoutArray) = dot(a,b) @inline LinearAlgebra.dot(a::LayoutArray, b::AbstractArray) = dot(a,b) @inline LinearAlgebra.dot(a::AbstractArray, b::LayoutArray) = dot(a,b) -@inline LinearAlgebra.dot(a::LayoutVector, b::AbstractFill{<:Any,1}) = FillArrays._fill_dot_rev(a,b) -@inline LinearAlgebra.dot(a::AbstractFill{<:Any,1}, b::LayoutVector) = FillArrays._fill_dot(a,b) @inline LinearAlgebra.dot(a::LayoutArray{<:Number}, b::SparseArrays.SparseVectorUnion{<:Number}) = dot(a,b) @inline LinearAlgebra.dot(a::SparseArrays.SparseVectorUnion{<:Number}, b::LayoutArray{<:Number}) = dot(a,b) @@ -379,17 +371,9 @@ LinearAlgebra.dot(x::AbstractVector, A::Symmetric{<:Real,<:LayoutMatrix}, y::Abs # allow overloading for infinite or lazy case @inline _power_by_squaring(_, _, A, p) = invoke(Base.power_by_squaring, Tuple{AbstractMatrix,Integer}, A, p) -# TODO: Remove unnecessary _apply -_apply(_, _, op, Λ::UniformScaling, A::AbstractMatrix) = op(Diagonal(Fill(Λ.λ,(axes(A,1),))), A) -_apply(_, _, op, A::AbstractMatrix, Λ::UniformScaling) = op(A, Diagonal(Fill(Λ.λ,(axes(A,1),)))) for Typ in (:LayoutMatrix, :(Symmetric{<:Any,<:LayoutMatrix}), :(Hermitian{<:Any,<:LayoutMatrix}), :(Adjoint{<:Any,<:LayoutMatrix}), :(Transpose{<:Any,<:LayoutMatrix})) - @eval begin - @inline Base.power_by_squaring(A::$Typ, p::Integer) = _power_by_squaring(MemoryLayout(A), size(A), A, p) - @inline +(A::$Typ, Λ::UniformScaling) = _apply(MemoryLayout(A), size(A), +, A, Λ) - @inline +(Λ::UniformScaling, A::$Typ) = _apply(MemoryLayout(A), size(A), +, Λ, A) - @inline -(A::$Typ, Λ::UniformScaling) = _apply(MemoryLayout(A), size(A), -, A, Λ) - @inline -(Λ::UniformScaling, A::$Typ) = _apply(MemoryLayout(A), size(A), -, Λ, A) - end + @eval @inline Base.power_by_squaring(A::$Typ, p::Integer) = + _power_by_squaring(MemoryLayout(A), size(A), A, p) end diff --git a/src/muladd.jl b/src/muladd.jl index 2873d21..2eed8f2 100644 --- a/src/muladd.jl +++ b/src/muladd.jl @@ -72,7 +72,6 @@ materialize(M::MulAdd) = copy(instantiate(M)) copy(M::MulAdd) = copyto!(similar(M), M) _fill_copyto!(dest, C) = copyto!(dest, C) -_fill_copyto!(dest, C::Zeros) = zero!(dest) # exploit special fill! overload @inline copyto!(dest::AbstractArray{T}, M::MulAdd) where T = muladd!(M.α, unalias(dest,M.A), unalias(dest,M.B), M.β, _fill_copyto!(dest, M.C)) @@ -362,18 +361,6 @@ materialize!(M::BlasMatMulVecAdd{<:HermitianLayout{<:AbstractRowMajor},<:Abstrac similar(M::MulAdd{<:DiagonalLayout,<:DiagonalLayout}, ::Type{T}, axes) where T = similar(M.B, T, axes) similar(M::MulAdd{<:DiagonalLayout}, ::Type{T}, axes) where T = similar(M.B, T, axes) similar(M::MulAdd{<:Any,<:DiagonalLayout}, ::Type{T}, axes) where T = similar(M.A, T, axes) -# equivalent to rescaling -function materialize!(M::MulAdd{<:DiagonalLayout{<:AbstractFillLayout}}) - checkdimensions(M) - M.C .= (M.α * getindex_value(M.A.diag)) .* M.B .+ M.β .* M.C - M.C -end - -function materialize!(M::MulAdd{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) - checkdimensions(M) - M.C .= M.α .* M.A .* getindex_value(M.B.diag) .+ M.β .* M.C - M.C -end BroadcastStyle(::Type{<:MulAdd}) = ApplyBroadcastStyle() @@ -383,11 +370,6 @@ scalarone(::Type{A}) where {A<:AbstractArray} = scalarone(eltype(A)) scalarzero(::Type{T}) where T = zero(T) scalarzero(::Type{A}) where {A<:AbstractArray} = scalarzero(eltype(A)) -fillzeros(::Type{T}, ax) where T<:Number = Zeros{T}(ax) -mulzeros(::Type{T}, M) where T<:Number = fillzeros(T, axes(M)) -mulzeros(::Type{T}, M::Mul{<:DualLayout,<:Any,<:Adjoint}) where T<:Number = fillzeros(T, axes(M,2))' -mulzeros(::Type{T}, M::Mul{<:DualLayout,<:Any,<:Transpose}) where T<:Number = transpose(fillzeros(T, axes(M,2))) - # initiate array-valued MulAdd function _mulzeros!(dest::AbstractVector{T}, A, B) where T for k in axes(dest,1)