Skip to content

Commit 5759ceb

Browse files
committed
Loosen signature in triangular solver from Strided- to AbstractMatrix.
Remove many unnecessary type parameters. Fixes #16196
1 parent a717fbe commit 5759ceb

File tree

5 files changed

+84
-21
lines changed

5 files changed

+84
-21
lines changed

base/linalg/bidiag.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,6 @@ end
224224
/(A::Bidiagonal, B::Number) = Bidiagonal(A.dv/B, A.ev/B, A.isupper)
225225
==(A::Bidiagonal, B::Bidiagonal) = (A.dv==B.dv) && (A.ev==B.ev) && (A.isupper==B.isupper)
226226

227-
228227
BiTriSym = Union{Bidiagonal, Tridiagonal, SymTridiagonal}
229228
BiTri = Union{Bidiagonal, Tridiagonal}
230229
A_mul_B!(C::AbstractMatrix, A::SymTridiagonal, B::BiTriSym) = A_mul_B_td!(C, A, B)
@@ -361,6 +360,11 @@ function A_mul_B_td!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym)
361360
C
362361
end
363362

363+
SpecialMatrix = Union{Bidiagonal, SymTridiagonal, Tridiagonal, AbstractTriangular}
364+
# to avoid ambiguity warning, but shouldn't be necessary
365+
*(A::AbstractTriangular, B::SpecialMatrix) = full(A) * full(B)
366+
*(A::SpecialMatrix, B::SpecialMatrix) = full(A) * full(B)
367+
364368
#Generic multiplication
365369
for func in (:*, :Ac_mul_B, :A_mul_Bc, :/, :A_rdiv_Bc)
366370
@eval ($func){T}(A::Bidiagonal{T}, B::AbstractVector{T}) = ($func)(full(A), B)

base/linalg/diagonal.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,22 @@ end
104104
/{T<:Number}(D::Diagonal, x::T) = Diagonal(D.diag / x)
105105
*(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag .* Db.diag)
106106
*(D::Diagonal, V::AbstractVector) = D.diag .* V
107+
# To avoid ambiguity in the definitions below
108+
for uplo in (:LowerTriangular, :UpperTriangular)
109+
@eval begin
110+
(*)(A::$uplo, D::Diagonal) = $uplo(A.data * D)
111+
112+
function (*)(A::$(Symbol(:Unit, uplo)), D::Diagonal)
113+
B = A.data * D
114+
for i = 1:size(A, 1)
115+
B[i,i] = D.diag[i]
116+
end
117+
return $uplo(B)
118+
end
119+
end
120+
end
121+
# (*)(A::AbstractTriangular, D::Diagonal) =
122+
# error("this method should never get called. Please make a bug report.")
107123
*(A::AbstractMatrix, D::Diagonal) =
108124
scale!(similar(A, promote_op(*, eltype(A), eltype(D.diag))), A, D.diag)
109125
*(D::Diagonal, A::AbstractMatrix) =

base/linalg/triangular.jl

Lines changed: 58 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,6 +1291,27 @@ function A_rdiv_Bt!(A::StridedMatrix, B::UnitLowerTriangular)
12911291
A
12921292
end
12931293

1294+
for f in (:A_rdiv_B!, :A_rdiv_Bc!, :A_rdiv_Bt!)
1295+
for (uplo, fuplo) in ((:Lower, :tril!), (:Upper, :triu!))
1296+
mat = Symbol(uplo, :Triangular)
1297+
umat = Symbol(:Unit, mat)
1298+
@eval begin
1299+
$f(A::$mat, B::Union{$mat,$umat}) = ($mat)($f($fuplo(A.data), B))
1300+
end
1301+
end
1302+
@eval $f(A::AbstractTriangular, B::AbstractTriangular) = $f(full!(A), B)
1303+
end
1304+
for f in (:A_mul_B!, :Ac_mul_B!, :At_mul_B!, :A_ldiv_B!, :Ac_ldiv_B!, :At_ldiv_B!)
1305+
for (uplo, fuplo) in ((:Lower, :tril!), (:Upper, :triu!))
1306+
mat = Symbol(uplo, :Triangular)
1307+
umat = Symbol(:Unit, mat)
1308+
@eval begin
1309+
$f(A::Union{$mat,$umat}, B::$mat) = ($mat)($f(A, $fuplo(B.data)))
1310+
end
1311+
end
1312+
@eval $f(A::AbstractTriangular, B::AbstractTriangular) = $f(A, full!(B))
1313+
end
1314+
12941315
# Promotion
12951316
## Promotion methods in matmul don't apply to triangular multiplication since it is inplace. Hence we have to make very similar definitions, but without allocation of a result array. For multiplication and unit diagonal division the element type doesn't have to be stable under division whereas that is necessary in the general triangular solve problem.
12961317

@@ -1301,72 +1322,91 @@ for t in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriang
13011322
end
13021323
end
13031324

1304-
for f in (:*, :Ac_mul_B, :At_mul_B, :\, :Ac_ldiv_B, :At_ldiv_B)
1305-
@eval begin
1306-
($f)(A::AbstractTriangular, B::AbstractTriangular) = ($f)(A, full(B))
1325+
for f in (:A_mul_Bc, :A_mul_Bt)
1326+
for (uplo, fuplo) in ((:Lower, :tril!), (:Upper, :triu!))
1327+
mat = Symbol(uplo, :Triangular)
1328+
umat = Symbol(:Unit, mat)
1329+
@eval begin
1330+
function $f(A::$mat, B::Union{$mat,$umat})
1331+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
1332+
return ($mat)($f(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B)))
1333+
end
1334+
end
13071335
end
1336+
@eval $f(A::AbstractTriangular, B::AbstractTriangular) = $f(full(A), B)
13081337
end
1309-
for f in (:A_mul_Bc, :A_mul_Bt, :Ac_mul_Bc, :At_mul_Bt, :/, :A_rdiv_Bc, :A_rdiv_Bt)
1310-
@eval begin
1311-
($f)(A::AbstractTriangular, B::AbstractTriangular) = ($f)(full(A), B)
1338+
for f in (:*, :Ac_mul_B, :At_mul_B)
1339+
for (uplo, fuplo) in ((:Lower, :tril!), (:Upper, :triu!))
1340+
mat = Symbol(uplo, :Triangular)
1341+
umat = Symbol(:Unit, mat)
1342+
@eval begin
1343+
function $f(A::Union{$mat,$umat}, B::$mat)
1344+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
1345+
return ($mat)($f(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB)))
1346+
end
1347+
end
13121348
end
1349+
@eval $f(A::AbstractTriangular, B::AbstractTriangular) = $f(A, full(B))
13131350
end
13141351

13151352
## The general promotion methods
1353+
for mat in (:AbstractVector, AbstractMatrix)
13161354
### Multiplication with triangle to the left and hence rhs cannot be transposed.
13171355
for (f, g) in ((:*, :A_mul_B!), (:Ac_mul_B, :Ac_mul_B!), (:At_mul_B, :At_mul_B!))
13181356
@eval begin
1319-
function ($f){TA,TB}(A::AbstractTriangular{TA}, B::StridedVecOrMat{TB})
1320-
TAB = typeof(zero(TA)*zero(TB) + zero(TA)*zero(TB))
1357+
function ($f)(A::AbstractTriangular, B::$mat)
1358+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
13211359
($g)(convert(AbstractArray{TAB}, A), copy_oftype(B, TAB))
13221360
end
13231361
end
13241362
end
13251363
### Left division with triangle to the left hence rhs cannot be transposed. No quotients.
13261364
for (f, g) in ((:\, :A_ldiv_B!), (:Ac_ldiv_B, :Ac_ldiv_B!), (:At_ldiv_B, :At_ldiv_B!))
13271365
@eval begin
1328-
function ($f){TA,TB,S}(A::Union{UnitUpperTriangular{TA,S},UnitLowerTriangular{TA,S}}, B::StridedVecOrMat{TB})
1329-
TAB = typeof(zero(TA)*zero(TB) + zero(TA)*zero(TB))
1366+
function ($f)(A::Union{UnitUpperTriangular,UnitLowerTriangular}, B::$mat)
1367+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
13301368
($g)(convert(AbstractArray{TAB}, A), copy_oftype(B, TAB))
13311369
end
13321370
end
13331371
end
13341372
### Left division with triangle to the left hence rhs cannot be transposed. Quotients.
13351373
for (f, g) in ((:\, :A_ldiv_B!), (:Ac_ldiv_B, :Ac_ldiv_B!), (:At_ldiv_B, :At_ldiv_B!))
13361374
@eval begin
1337-
function ($f){TA,TB,S}(A::Union{UpperTriangular{TA,S},LowerTriangular{TA,S}}, B::StridedVecOrMat{TB})
1338-
TAB = typeof((zero(TA)*zero(TB) + zero(TA)*zero(TB))/one(TA))
1375+
function ($f)(A::Union{UpperTriangular,LowerTriangular}, B::$mat)
1376+
TAB = typeof((zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))/one(eltype(A)))
13391377
($g)(convert(AbstractArray{TAB}, A), copy_oftype(B, TAB))
13401378
end
13411379
end
13421380
end
13431381
### Multiplication with triangle to the rigth and hence lhs cannot be transposed.
13441382
for (f, g) in ((:*, :A_mul_B!), (:A_mul_Bc, :A_mul_Bc!), (:A_mul_Bt, :A_mul_Bt!))
13451383
@eval begin
1346-
function ($f){TA,TB}(A::StridedVecOrMat{TA}, B::AbstractTriangular{TB})
1347-
TAB = typeof(zero(TA)*zero(TB) + zero(TA)*zero(TB))
1384+
function ($f)(A::$mat, B::AbstractTriangular)
1385+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
13481386
($g)(copy_oftype(A, TAB), convert(AbstractArray{TAB}, B))
13491387
end
13501388
end
13511389
end
13521390
### Right division with triangle to the right hence lhs cannot be transposed. No quotients.
13531391
for (f, g) in ((:/, :A_rdiv_B!), (:A_rdiv_Bc, :A_rdiv_Bc!), (:A_rdiv_Bt, :A_rdiv_Bt!))
13541392
@eval begin
1355-
function ($f){TA,TB,S}(A::StridedVecOrMat{TA}, B::Union{UnitUpperTriangular{TB,S},UnitLowerTriangular{TB,S}})
1356-
TAB = typeof(zero(TA)*zero(TB) + zero(TA)*zero(TB))
1393+
function ($f)(A::$mat, B::Union{UnitUpperTriangular,UnitLowerTriangular})
1394+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
13571395
($g)(copy_oftype(A, TAB), convert(AbstractArray{TAB}, B))
13581396
end
13591397
end
13601398
end
1399+
13611400
### Right division with triangle to the right hence lhs cannot be transposed. Quotients.
13621401
for (f, g) in ((:/, :A_rdiv_B!), (:A_rdiv_Bc, :A_rdiv_Bc!), (:A_rdiv_Bt, :A_rdiv_Bt!))
13631402
@eval begin
1364-
function ($f){TA,TB,S}(A::StridedVecOrMat{TA}, B::Union{UpperTriangular{TB,S},LowerTriangular{TB,S}})
1365-
TAB = typeof((zero(TA)*zero(TB) + zero(TA)*zero(TB))/one(TA))
1403+
function ($f)(A::$mat, B::Union{UpperTriangular,LowerTriangular})
1404+
TAB = typeof((zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))/one(eltype(A)))
13661405
($g)(copy_oftype(A, TAB), convert(AbstractArray{TAB}, B))
13671406
end
13681407
end
13691408
end
1409+
end
13701410
### Fallbacks brought in from linalg/bidiag.jl while fixing #14506.
13711411
# Eventually the above promotion methods should be generalized as
13721412
# was done for bidiagonal matrices in #14506.

base/sparse/sparsevector.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1525,7 +1525,7 @@ for isunittri in (true, false), islowertri in (true, false)
15251525
(true, :(Ac_ldiv_B), :(Ac_ldiv_B!)) )
15261526

15271527
# broad method where elements are Numbers
1528-
@eval function ($func){TA<:Number,Tb<:Number,S}(A::$tritype{TA,S}, b::SparseVector{Tb})
1528+
@eval function ($func){TA<:Number,Tb<:Number,S<:AbstractMatrix}(A::$tritype{TA,S}, b::SparseVector{Tb})
15291529
TAb = $(isunittri ?
15301530
:(typeof(zero(TA)*zero(Tb) + zero(TA)*zero(Tb))) :
15311531
:(typeof((zero(TA)*zero(Tb) + zero(TA)*zero(Tb))/one(TA))) )
@@ -1551,7 +1551,7 @@ for isunittri in (true, false), islowertri in (true, false)
15511551
end
15521552

15531553
# fallback where elements are not Numbers
1554-
@eval ($func){TA,Tb,S}(A::$tritype{TA,S}, b::SparseVector{Tb}) = ($ipfunc)(A, copy(b))
1554+
@eval ($func)(A::$tritype, b::SparseVector) = ($ipfunc)(A, copy(b))
15551555
end
15561556

15571557
# build in-place left-division operations

test/linalg/triangular.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,3 +489,6 @@ end
489489
# Test that UpperTriangular(LowerTriangular) throws. See #16201
490490
@test_throws ArgumentError LowerTriangular(UpperTriangular(randn(3,3)))
491491
@test_throws ArgumentError UpperTriangular(LowerTriangular(randn(3,3)))
492+
493+
# Issue 16196
494+
@test UpperTriangular(eye(3)) \ sub(ones(3), [1,2,3]) == ones(3)

0 commit comments

Comments
 (0)