Skip to content

Commit 12803db

Browse files
committed
Merge pull request #14506 from Sacha0/bidiagbackslashprom
fix promotion in bidiagonal backslash, At_ldiv_B, and Ac_ldiv_B. add missing bidiagonal matrix conversions.
2 parents 82f3ad4 + 16c402e commit 12803db

File tree

3 files changed

+56
-27
lines changed

3 files changed

+56
-27
lines changed

base/linalg/bidiag.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,13 @@ function convert{T}(::Type{Tridiagonal{T}}, A::Bidiagonal)
7777
end
7878
promote_rule{T,S}(::Type{Tridiagonal{T}}, ::Type{Bidiagonal{S}})=Tridiagonal{promote_type(T,S)}
7979

80+
# No-op for trivial conversion Bidiagonal{T} -> Bidiagonal{T}
81+
convert{T}(::Type{Bidiagonal{T}}, A::Bidiagonal{T}) = A
82+
# Convert Bidiagonal{Told} to Bidiagonal{Tnew} by constructing a new instance with converted elements
83+
convert{Tnew,Told}(::Type{Bidiagonal{Tnew}}, A::Bidiagonal{Told}) = Bidiagonal(convert(Vector{Tnew}, A.dv), convert(Vector{Tnew}, A.ev), A.isupper)
84+
# When asked to convert Bidiagonal{Told} to AbstractMatrix{Tnew}, preserve structure by converting to Bidiagonal{Tnew} <: AbstractMatrix{Tnew}
85+
convert{Tnew,Told}(::Type{AbstractMatrix{Tnew}}, A::Bidiagonal{Told}) = convert(Bidiagonal{Tnew}, A)
86+
8087
big(B::Bidiagonal) = Bidiagonal(big(B.dv), big(B.ev), B.isupper)
8188

8289
###################
@@ -230,7 +237,6 @@ function A_ldiv_B!(A::Union{Bidiagonal, AbstractTriangular}, B::AbstractMatrix)
230237
end
231238
B
232239
end
233-
234240
for func in (:Ac_ldiv_B!, :At_ldiv_B!)
235241
@eval function ($func)(A::Union{Bidiagonal, AbstractTriangular}, B::AbstractMatrix)
236242
nA,mA = size(A)
@@ -247,9 +253,6 @@ for func in (:Ac_ldiv_B!, :At_ldiv_B!)
247253
B
248254
end
249255
end
250-
Ac_ldiv_B(A::Union{Bidiagonal, AbstractTriangular}, B::AbstractMatrix) = Ac_ldiv_B!(A,copy(B))
251-
At_ldiv_B(A::Union{Bidiagonal, AbstractTriangular}, B::AbstractMatrix) = At_ldiv_B!(A,copy(B))
252-
253256
#Generic solver using naive substitution
254257
function naivesub!{T}(A::Bidiagonal{T}, b::AbstractVector, x::AbstractVector = b)
255258
N = size(A, 2)
@@ -272,9 +275,15 @@ function naivesub!{T}(A::Bidiagonal{T}, b::AbstractVector, x::AbstractVector = b
272275
x
273276
end
274277

275-
function \{T,S}(A::Bidiagonal{T}, B::AbstractVecOrMat{S})
276-
TS = typeof(zero(T)*zero(S) + zero(T)*zero(S))
277-
TS == S ? A_ldiv_B!(A, copy(B)) : A_ldiv_B!(A, convert(AbstractArray{TS}, B))
278+
### Generic promotion methods and fallbacks
279+
for (f,g) in ((:\, :A_ldiv_B!), (:At_ldiv_B, :At_ldiv_B!), (:Ac_ldiv_B, :Ac_ldiv_B!))
280+
@eval begin
281+
function ($f){TA<:Number,TB<:Number}(A::Bidiagonal{TA}, B::AbstractVecOrMat{TB})
282+
TAB = typeof((zero(TA)*zero(TB) + zero(TA)*zero(TB))/one(TA))
283+
($g)(convert(AbstractArray{TAB}, A), copy_oftype(B, TAB))
284+
end
285+
($f)(A::Bidiagonal, B::AbstractVecOrMat) = ($g)(A, copy(B))
286+
end
278287
end
279288

280289
factorize(A::Bidiagonal) = A

base/linalg/triangular.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1366,6 +1366,11 @@ for (f, g) in ((:/, :A_rdiv_B!), (:A_rdiv_Bc, :A_rdiv_Bc!), (:A_rdiv_Bt, :A_rdiv
13661366
end
13671367
end
13681368
end
1369+
### Fallbacks brought in from linalg/bidiag.jl while fixing #14506.
1370+
# Eventually the above promotion methods should be generalized as
1371+
# was done for bidiagonal matrices in #14506.
1372+
At_ldiv_B(A::AbstractTriangular, B::AbstractVecOrMat) = At_ldiv_B!(A, copy(B))
1373+
Ac_ldiv_B(A::AbstractTriangular, B::AbstractVecOrMat) = Ac_ldiv_B!(A, copy(B))
13691374

13701375
# Complex matrix logarithm for the upper triangular factor, see:
13711376
# Al-Mohy and Higham, "Improved inverse scaling and squaring algorithms for

test/linalg/bidiag.jl

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,28 @@ n = 10 #Size of test matrix
88
srand(1)
99

1010
debug && println("Bidiagonal matrices")
11-
for relty in (Float32, Float64, BigFloat), elty in (relty, Complex{relty})
11+
for relty in (Int, Float32, Float64, BigFloat), elty in (relty, Complex{relty})
1212
debug && println("elty is $(elty), relty is $(relty)")
13-
dv = convert(Vector{elty}, randn(n))
14-
ev = convert(Vector{elty}, randn(n-1))
15-
b = convert(Matrix{elty}, randn(n, 2))
16-
c = convert(Matrix{elty}, randn(n, n))
17-
if (elty <: Complex)
18-
dv += im*convert(Vector{elty}, randn(n))
19-
ev += im*convert(Vector{elty}, randn(n-1))
20-
b += im*convert(Matrix{elty}, randn(n, 2))
13+
if relty <: AbstractFloat
14+
dv = convert(Vector{elty}, randn(n))
15+
ev = convert(Vector{elty}, randn(n-1))
16+
b = convert(Matrix{elty}, randn(n, 2))
17+
c = convert(Matrix{elty}, randn(n, n))
18+
if (elty <: Complex)
19+
dv += im*convert(Vector{elty}, randn(n))
20+
ev += im*convert(Vector{elty}, randn(n-1))
21+
b += im*convert(Matrix{elty}, randn(n, 2))
22+
end
23+
elseif relty <: Integer
24+
dv = convert(Vector{elty}, rand(1:10, n))
25+
ev = convert(Vector{elty}, rand(1:10, n-1))
26+
b = convert(Matrix{elty}, rand(1:10, n, 2))
27+
c = convert(Matrix{elty}, rand(1:10, n, n))
28+
if (elty <: Complex)
29+
dv += im*convert(Vector{elty}, rand(1:10, n))
30+
ev += im*convert(Vector{elty}, rand(1:10, n-1))
31+
b += im*convert(Matrix{elty}, rand(1:10, n, 2))
32+
end
2133
end
2234

2335
debug && println("Test constructors")
@@ -89,23 +101,24 @@ for relty in (Float32, Float64, BigFloat), elty in (relty, Complex{relty})
89101
condT = cond(map(Complex128,Tfull))
90102
x = T \ b
91103
tx = Tfull \ b
104+
promty = typeof((zero(relty)*zero(relty) + zero(relty)*zero(relty))/one(relty))
92105
@test_throws DimensionMismatch Base.LinAlg.naivesub!(T,ones(elty,n+1))
93-
@test norm(x-tx,Inf) <= 4*condT*max(eps()*norm(tx,Inf), eps(relty)*norm(x,Inf))
106+
@test norm(x-tx,Inf) <= 4*condT*max(eps()*norm(tx,Inf), eps(promty)*norm(x,Inf))
94107
@test_throws DimensionMismatch T \ ones(elty,n+1,2)
95108
@test_throws DimensionMismatch T.' \ ones(elty,n+1,2)
96109
@test_throws DimensionMismatch T' \ ones(elty,n+1,2)
97110
if relty != BigFloat
98111
x = T.'\c.'
99112
tx = Tfull.' \ c.'
100-
@test norm(x-tx,Inf) <= 4*condT*max(eps()*norm(tx,Inf), eps(relty)*norm(x,Inf))
113+
elty <: AbstractFloat && @test norm(x-tx,Inf) <= 4*condT*max(eps()*norm(tx,Inf), eps(promty)*norm(x,Inf))
101114
@test_throws DimensionMismatch T.'\b.'
102115
x = T'\c.'
103116
tx = Tfull' \ c.'
104-
@test norm(x-tx,Inf) <= 4*condT*max(eps()*norm(tx,Inf), eps(relty)*norm(x,Inf))
117+
@test norm(x-tx,Inf) <= 4*condT*max(eps()*norm(tx,Inf), eps(promty)*norm(x,Inf))
105118
@test_throws DimensionMismatch T'\b.'
106119
x = T\c.'
107120
tx = Tfull\c.'
108-
@test norm(x-tx,Inf) <= 4*condT*max(eps()*norm(tx,Inf), eps(relty)*norm(x,Inf))
121+
@test norm(x-tx,Inf) <= 4*condT*max(eps()*norm(tx,Inf), eps(promty)*norm(x,Inf))
109122
@test_throws DimensionMismatch T\b.'
110123
end
111124

@@ -133,11 +146,13 @@ for relty in (Float32, Float64, BigFloat), elty in (relty, Complex{relty})
133146
@test_throws ArgumentError diag(T,n+1)
134147

135148
debug && println("Eigensystems")
136-
d1, v1 = eig(T)
137-
d2, v2 = eig(map(elty<:Complex ? Complex128 : Float64,Tfull))
138-
@test_approx_eq isupper?d1:reverse(d1) d2
139-
if elty <: Real
140-
Test.test_approx_eq_modphase(v1, isupper?v2:v2[:,n:-1:1])
149+
if relty <: AbstractFloat
150+
d1, v1 = eig(T)
151+
d2, v2 = eig(map(elty<:Complex ? Complex128 : Float64,Tfull))
152+
@test_approx_eq isupper?d1:reverse(d1) d2
153+
if elty <: Real
154+
Test.test_approx_eq_modphase(v1, isupper?v2:v2[:,n:-1:1])
155+
end
141156
end
142157

143158
debug && println("Singular systems")
@@ -161,8 +176,8 @@ for relty in (Float32, Float64, BigFloat), elty in (relty, Complex{relty})
161176
@test convert(elty,-1.0) * T == Bidiagonal(-T.dv,-T.ev,T.isupper)
162177
@test T * convert(elty,-1.0) == Bidiagonal(-T.dv,-T.ev,T.isupper)
163178
for isupper2 in (true, false)
164-
dv = convert(Vector{elty}, randn(n))
165-
ev = convert(Vector{elty}, randn(n-1))
179+
dv = convert(Vector{elty}, relty <: AbstractFloat ? randn(n) : rand(1:10, n))
180+
ev = convert(Vector{elty}, relty <: AbstractFloat ? randn(n-1) : rand(1:10, n-1))
166181
T2 = Bidiagonal(dv, ev, isupper2)
167182
Tfull2 = full(T2)
168183
for op in (+, -, *)

0 commit comments

Comments
 (0)