Skip to content

Commit 3777aab

Browse files
authored
Fix broadcast with RowVectors of matrices (#20980)
Fix #20979. Amusingly, this bug is a direct result of `transpose` being recursive.
1 parent f0aedc6 commit 3777aab

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

base/linalg/rowvector.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,16 +140,17 @@ end
140140
@inline check_tail_indices(i1, i2, i3, is...) = i3 == 1 ? check_tail_indices(i1, i2, is...) : false
141141

142142
# helper function for below
143-
@inline to_vec(rowvec::RowVector) = transpose(rowvec)
143+
@inline to_vec(rowvec::RowVector) = map(transpose, transpose(rowvec))
144144
@inline to_vec(x::Number) = x
145145
@inline to_vecs(rowvecs...) = (map(to_vec, rowvecs)...)
146146

147-
# map
148-
@inline map(f, rowvecs::RowVector...) = RowVector(map(f, to_vecs(rowvecs...)...))
147+
# map: Preserve the RowVector by un-wrapping and re-wrapping, but note that `f`
148+
# expects to operate within the transposed domain, so to_vec transposes the elements
149+
@inline map(f, rowvecs::RowVector...) = RowVector(map(transposef, to_vecs(rowvecs...)...))
149150

150151
# broacast (other combinations default to higher-dimensional array)
151152
@inline broadcast(f, rowvecs::Union{Number,RowVector}...) =
152-
RowVector(broadcast(f, to_vecs(rowvecs...)...))
153+
RowVector(broadcast(transposef, to_vecs(rowvecs...)...))
153154

154155
# Horizontal concatenation #
155156

test/linalg/rowvector.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,21 @@ end
250250
end
251251
end
252252

253+
@testset "issue #20979" begin
254+
f20979(z::Complex) = [z.re -z.im; z.im z.re]
255+
v = [1+2im]'
256+
@test (f20979.(v))[1] == f20979(v[1])
257+
@test f20979.(v) == f20979.(collect(v))
258+
259+
w = rand(Complex128, 3)
260+
@test f20979.(v') == f20979.(collect(v')) == (f20979.(v))'
261+
262+
g20979(x, y) = [x[2,1] x[1,2]; y[1,2] y[2,1]]
263+
v = [rand(2,2), rand(2,2), rand(2,2)]
264+
@test g20979.(v', v') == g20979.(collect(v'), collect(v')) ==
265+
map(g20979, v', v') == map(g20979, collect(v'), collect(v'))
266+
end
267+
253268
@testset "ambiguity between * methods with RowVectors and ConjRowVectors (#20971)" begin
254269
@test RowVector(ConjArray(ones(4))) * ones(4) == 4
255-
end
270+
end

0 commit comments

Comments
 (0)