diff --git a/src/host/mapreduce.jl b/src/host/mapreduce.jl index 32520ebc..c589a3de 100644 --- a/src/host/mapreduce.jl +++ b/src/host/mapreduce.jl @@ -10,6 +10,30 @@ mapreducedim!(f, op, R::AnyGPUArray, A::AbstractArrayOrBroadcasted; Base.mapreducedim!(f, op, R::AnyGPUArray, A::AbstractArray) = mapreducedim!(f, op, R, A) Base.mapreducedim!(f, op, R::AnyGPUArray, A::Broadcast.Broadcasted) = mapreducedim!(f, op, R, A) +# resolve ambiguities with Adjoint/Transpose +# https://github.com/JuliaLang/julia/blob/master/stdlib/LinearAlgebra/src/adjtrans.jl#L440-L448 +Base.mapreducedim!(f, op::LinearAlgebra.CommutativeOps, B::AnyGPUArray, A::LinearAlgebra.TransposeAbsMat) = + (Base.mapreducedim!(f∘transpose, op, LinearAlgebra.switch_dim12(B), parent(A)); B) +Base.mapreducedim!(f, op::LinearAlgebra.CommutativeOps, B::AnyGPUArray, A::LinearAlgebra.AdjointAbsMat) = + (Base.mapreducedim!(f∘adjoint, op, LinearAlgebra.switch_dim12(B), parent(A)); B) +Base.mapreducedim!(f::typeof(identity), op::Union{typeof(*),typeof(Base.mul_prod)}, B::AnyGPUArray, A::LinearAlgebra.TransposeAbsMat{<:Union{Real,Complex}}) = + (Base.mapreducedim!(f∘transpose, op, LinearAlgebra.switch_dim12(B), parent(A)); B) +Base.mapreducedim!(f::typeof(identity), op::Union{typeof(*),typeof(Base.mul_prod)}, B::AnyGPUArray, A::LinearAlgebra.AdjointAbsMat{<:Union{Real,Complex}}) = + (Base.mapreducedim!(f∘adjoint, op, LinearAlgebra.switch_dim12(B), parent(A)); B) + +# resolve ambiguities with PermutedDimsArray +# https://github.com/JuliaLang/julia/blob/master/base/permuteddimsarray.jl#L355-L364 +function Base.mapreducedim!(f, op::LinearAlgebra.CommutativeOps, B::AnyGPUArray{T,N}, A::PermutedDimsArray{S,N,perm,iperm}) where {T,S,N,perm,iperm} + C = PermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output + Base.mapreducedim!(f, op, C, parent(A)) + B +end +function Base.mapreducedim!(f::typeof(identity), op::Union{typeof(Base.mul_prod),typeof(*)}, B::AnyGPUArray{T,N}, A::PermutedDimsArray{<:Union{Real,Complex},N,perm,iperm}) where {T,N,perm,iperm} + C = PermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output + Base.mapreducedim!(f, op, C, parent(A)) + B +end + neutral_element(op, T) = error("""GPUArrays.jl needs to know the neutral element for your operator `$op`. Please pass it as an explicit argument to `GPUArrays.mapreducedim!`,