Skip to content

Commit 18bed58

Browse files
authored
Merge pull request #25238 from Sacha0/higho2
optimize and fix map/broadcast over Adjoint/Transpose vectors, take 2
2 parents eb91796 + 5ff0863 commit 18bed58

File tree

2 files changed

+19
-24
lines changed

2 files changed

+19
-24
lines changed

base/linalg/adjtrans.jl

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -136,30 +136,15 @@ hcat(tvs::Transpose{T,Vector{T}}...) where {T} = _transpose_hcat(tvs...)
136136
### higher order functions
137137
# preserve Adjoint/Transpose wrapper around vectors
138138
# to retain the associated semantics post-map/broadcast
139-
140-
# vectorfy takes an Adoint/Transpose-wrapped vector and builds
141-
# an unwrapped vector with the entrywise-same contents
142-
vectorfy(x::Number) = x
143-
vectorfy(adjvec::AdjointAbsVec) = map(Adjoint, adjvec.parent)
144-
vectorfy(transvec::TransposeAbsVec) = map(Transpose, transvec.parent)
145-
vectorfyall(transformedvecs...) = (map(vectorfy, transformedvecs)...,)
146-
147-
# map over collections of Adjoint/Transpose-wrapped vectors
148-
# note that the caller's operation `f` should be applied to the entries of the wrapped
149-
# vectors, rather than the entires of the wrapped vector's parents. so first we use vectorfy
150-
# to build unwrapped vectors with entrywise-same contents as the wrapped input vectors.
151-
# then we map the caller's operation over that set of unwrapped vectors. but now re-wrapping
152-
# the resulting vector would inappropriately transform the result vector's entries. so
153-
# instead of simply mapping the caller's operation over the set of unwrapped vectors,
154-
# we map Adjoint/Transpose composed with the caller's operationt over the set of unwrapped
155-
# vectors. then re-wrapping the result vector yields a wrapped vector with the correct entries.
156-
map(f, avs::AdjointAbsVec...) = Adjoint(map(Adjointf, vectorfyall(avs...)...))
157-
map(f, tvs::TransposeAbsVec...) = Transpose(map(Transposef, vectorfyall(tvs...)...))
158-
159-
# broadcast over collections of Adjoint/Transpose-wrapped vectors and numbers
160-
# similar explanation for these definitions as for map above
161-
broadcast(f, avs::Union{Number,AdjointAbsVec}...) = Adjoint(broadcast(Adjointf, vectorfyall(avs...)...))
162-
broadcast(f, tvs::Union{Number,TransposeAbsVec}...) = Transpose(broadcast(Transposef, vectorfyall(tvs...) ...))
139+
#
140+
# note that the caller's operation f operates in the domain of the wrapped vectors' entries.
141+
# hence the Adjoint->f->Adjoint shenanigans applied to the parent vectors' entries.
142+
map(f, avs::AdjointAbsVec...) = Adjoint(map((xs...) -> Adjoint(f(Adjoint.(xs)...)), parent.(avs)...))
143+
map(f, tvs::TransposeAbsVec...) = Transpose(map((xs...) -> Transpose(f(Transpose.(xs)...)), parent.(tvs)...))
144+
quasiparentt(x) = parent(x); quasiparentt(x::Number) = x # to handle numbers in the defs below
145+
quasiparenta(x) = parent(x); quasiparenta(x::Number) = conj(x) # to handle numbers in the defs below
146+
broadcast(f, avs::Union{Number,AdjointAbsVec}...) = Adjoint(broadcast((xs...) -> Adjoint(f(Adjoint.(xs)...)), quasiparenta.(avs)...))
147+
broadcast(f, tvs::Union{Number,TransposeAbsVec}...) = Transpose(broadcast((xs...) -> Transpose(f(Transpose.(xs)...)), quasiparentt.(tvs)...))
163148

164149

165150
### linear algebra

test/linalg/adjtrans.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,16 @@ end
338338
# trinary broadcast over wrapped vectors with concrete scalar eltype and numbers
339339
@test broadcast(+, Adjoint(vec), 1, Adjoint(vec))::Adjoint{Complex{Int},Vector{Complex{Int}}} == avec + avec .+ 1
340340
@test broadcast(+, Transpose(vec), 1, Transpose(vec))::Transpose{Complex{Int},Vector{Complex{Int}}} == tvec + tvec .+ 1
341+
@test broadcast(+, Adjoint(vec), 1im, Adjoint(vec))::Adjoint{Complex{Int},Vector{Complex{Int}}} == avec + avec .+ 1im
342+
@test broadcast(+, Transpose(vec), 1im, Transpose(vec))::Transpose{Complex{Int},Vector{Complex{Int}}} == tvec + tvec .+ 1im
343+
# ascertain inference friendliness, ref. https://github.com/JuliaLang/julia/pull/25083#issuecomment-353031641
344+
sparsevec = SparseVector([1.0, 2.0, 3.0])
345+
@test map(-, Adjoint(sparsevec), Adjoint(sparsevec)) isa Adjoint{Float64,SparseVector{Float64,Int}}
346+
@test map(-, Transpose(sparsevec), Transpose(sparsevec)) isa Transpose{Float64,SparseVector{Float64,Int}}
347+
@test broadcast(-, Adjoint(sparsevec), Adjoint(sparsevec)) isa Adjoint{Float64,SparseVector{Float64,Int}}
348+
@test broadcast(-, Transpose(sparsevec), Transpose(sparsevec)) isa Transpose{Float64,SparseVector{Float64,Int}}
349+
@test broadcast(+, Adjoint(sparsevec), 1.0, Adjoint(sparsevec)) isa Adjoint{Float64,SparseVector{Float64,Int}}
350+
@test broadcast(+, Transpose(sparsevec), 1.0, Transpose(sparsevec)) isa Transpose{Float64,SparseVector{Float64,Int}}
341351
end
342352

343353
@testset "Adjoint/Transpose-wrapped vector multiplication" begin

0 commit comments

Comments
 (0)