Skip to content

Commit 69e6e28

Browse files
committed
Loosen signature in triangular solver from Strided- to AbstractMatrix.
Remove many unnecessary type parameters. Fixes #16196
1 parent a0f4920 commit 69e6e28

File tree

5 files changed

+72
-17
lines changed

5 files changed

+72
-17
lines changed

base/linalg/bidiag.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,9 @@ end
225225
==(A::Bidiagonal, B::Bidiagonal) = (A.dv==B.dv) && (A.ev==B.ev) && (A.isupper==B.isupper)
226226

227227
SpecialMatrix = Union{Bidiagonal, SymTridiagonal, Tridiagonal, AbstractTriangular}
228-
*(A::SpecialMatrix, B::SpecialMatrix)=full(A)*full(B)
228+
# to avoid ambiguity warning, but shouldn't be necessary
229+
*(A::AbstractTriangular, B::SpecialMatrix) = full(A) * full(B)
230+
*(A::SpecialMatrix, B::SpecialMatrix) = full(A) * full(B)
229231

230232
#Generic multiplication
231233
for func in (:*, :Ac_mul_B, :A_mul_Bc, :/, :A_rdiv_Bc)

base/linalg/diagonal.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,22 @@ end
106106
/{T<:Number}(D::Diagonal, x::T) = Diagonal(D.diag / x)
107107
*(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag .* Db.diag)
108108
*(D::Diagonal, V::AbstractVector) = D.diag .* V
109+
# To avoid ambiguity in the definitions below
110+
for uplo in (:LowerTriangular, :UpperTriangular)
111+
@eval begin
112+
(*)(A::$uplo, D::Diagonal) = $uplo(A.data * D)
113+
114+
function (*)(A::$(Symbol(:Unit, uplo)), D::Diagonal)
115+
B = A.data * D
116+
for i = 1:size(A, 1)
117+
B[i,i] = D.diag[i]
118+
end
119+
return $uplo(B)
120+
end
121+
end
122+
end
123+
(*)(A::AbstractTriangular, D::Diagonal) =
124+
error("this method should never get called. Please make a bug report.")
109125
*(A::AbstractMatrix, D::Diagonal) =
110126
scale!(similar(A, promote_op(*, eltype(A), eltype(D.diag))), A, D.diag)
111127
*(D::Diagonal, A::AbstractMatrix) =

base/linalg/triangular.jl

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1287,6 +1287,27 @@ function A_rdiv_Bt!(A::StridedMatrix, B::UnitLowerTriangular)
12871287
A
12881288
end
12891289

1290+
for f in (:A_rdiv_B!, :A_rdiv_Bc!, :A_rdiv_Bt!)
1291+
for (uplo, fuplo) in ((:Lower, :tril!), (:Upper, :triu!))
1292+
mat = Symbol(uplo, :Triangular)
1293+
umat = Symbol(:Unit, mat)
1294+
@eval begin
1295+
$f(A::$mat, B::Union{$mat,$umat}) = ($mat)($f($fuplo(A.data), B))
1296+
end
1297+
end
1298+
@eval $f(A::AbstractTriangular, B::AbstractTriangular) = $f(full!(A), B)
1299+
end
1300+
for f in (:A_ldiv_B!, :Ac_ldiv_B!, :At_ldiv_B!)
1301+
for (uplo, fuplo) in ((:Lower, :tril!), (:Upper, :triu!))
1302+
mat = Symbol(uplo, :Triangular)
1303+
umat = Symbol(:Unit, mat)
1304+
@eval begin
1305+
$f(A::Union{$mat,$umat}, B::$mat) = ($mat)($f(A, $fuplo(B.data)))
1306+
end
1307+
end
1308+
@eval $f(A::AbstractTriangular, B::AbstractTriangular) = $f(A, full!(B))
1309+
end
1310+
12901311
# Promotion
12911312
## Promotion methods in matmul don't apply to triangular multiplication since it is inplace. Hence we have to make very similar definitions, but without allocation of a result array. For multiplication and unit diagonal division the element type doesn't have to be stable under division whereas that is necessary in the general triangular solve problem.
12921313

@@ -1297,72 +1318,85 @@ for t in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriang
12971318
end
12981319
end
12991320

1300-
for f in (:*, :Ac_mul_B, :At_mul_B, :\, :Ac_ldiv_B, :At_ldiv_B)
1321+
# for f in (:*, :Ac_mul_B, :At_mul_B, :\, :Ac_ldiv_B, :At_ldiv_B)
1322+
# @eval begin
1323+
# ($f)(A::AbstractTriangular, B::AbstractTriangular) = ($f)(A, full(B))
1324+
# end
1325+
# end
1326+
for f in (:*, :Ac_mul_B, :At_mul_B)
13011327
@eval begin
13021328
($f)(A::AbstractTriangular, B::AbstractTriangular) = ($f)(A, full(B))
13031329
end
13041330
end
1305-
for f in (:A_mul_Bc, :A_mul_Bt, :Ac_mul_Bc, :At_mul_Bt, :/, :A_rdiv_Bc, :A_rdiv_Bt)
1331+
# for f in (:A_mul_Bc, :A_mul_Bt, :Ac_mul_Bc, :At_mul_Bt, :/, :A_rdiv_Bc, :A_rdiv_Bt)
1332+
# @eval begin
1333+
# ($f)(A::AbstractTriangular, B::AbstractTriangular) = ($f)(full(A), B)
1334+
# end
1335+
# end
1336+
for f in (:A_mul_Bc, :A_mul_Bt, :Ac_mul_Bc, :At_mul_Bt)
13061337
@eval begin
13071338
($f)(A::AbstractTriangular, B::AbstractTriangular) = ($f)(full(A), B)
13081339
end
13091340
end
13101341

13111342
## The general promotion methods
1343+
for mat in (:AbstractVector, AbstractMatrix)
13121344
### Multiplication with triangle to the left and hence rhs cannot be transposed.
13131345
for (f, g) in ((:*, :A_mul_B!), (:Ac_mul_B, :Ac_mul_B!), (:At_mul_B, :At_mul_B!))
13141346
@eval begin
1315-
function ($f){TA,TB}(A::AbstractTriangular{TA}, B::StridedVecOrMat{TB})
1316-
TAB = typeof(zero(TA)*zero(TB) + zero(TA)*zero(TB))
1347+
function ($f)(A::AbstractTriangular, B::$mat)
1348+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
13171349
($g)(convert(AbstractArray{TAB}, A), copy_oftype(B, TAB))
13181350
end
13191351
end
13201352
end
13211353
### Left division with triangle to the left hence rhs cannot be transposed. No quotients.
13221354
for (f, g) in ((:\, :A_ldiv_B!), (:Ac_ldiv_B, :Ac_ldiv_B!), (:At_ldiv_B, :At_ldiv_B!))
13231355
@eval begin
1324-
function ($f){TA,TB,S}(A::Union{UnitUpperTriangular{TA,S},UnitLowerTriangular{TA,S}}, B::StridedVecOrMat{TB})
1325-
TAB = typeof(zero(TA)*zero(TB) + zero(TA)*zero(TB))
1356+
function ($f)(A::Union{UnitUpperTriangular,UnitLowerTriangular}, B::$mat)
1357+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
13261358
($g)(convert(AbstractArray{TAB}, A), copy_oftype(B, TAB))
13271359
end
13281360
end
13291361
end
13301362
### Left division with triangle to the left hence rhs cannot be transposed. Quotients.
13311363
for (f, g) in ((:\, :A_ldiv_B!), (:Ac_ldiv_B, :Ac_ldiv_B!), (:At_ldiv_B, :At_ldiv_B!))
13321364
@eval begin
1333-
function ($f){TA,TB,S}(A::Union{UpperTriangular{TA,S},LowerTriangular{TA,S}}, B::StridedVecOrMat{TB})
1334-
TAB = typeof((zero(TA)*zero(TB) + zero(TA)*zero(TB))/one(TA))
1365+
function ($f)(A::Union{UpperTriangular,LowerTriangular}, B::$mat)
1366+
TAB = typeof((zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))/one(eltype(A)))
13351367
($g)(convert(AbstractArray{TAB}, A), copy_oftype(B, TAB))
13361368
end
13371369
end
13381370
end
13391371
### Multiplication with triangle to the rigth and hence lhs cannot be transposed.
13401372
for (f, g) in ((:*, :A_mul_B!), (:A_mul_Bc, :A_mul_Bc!), (:A_mul_Bt, :A_mul_Bt!))
13411373
@eval begin
1342-
function ($f){TA,TB}(A::StridedVecOrMat{TA}, B::AbstractTriangular{TB})
1343-
TAB = typeof(zero(TA)*zero(TB) + zero(TA)*zero(TB))
1374+
function ($f)(A::$mat, B::AbstractTriangular)
1375+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
13441376
($g)(copy_oftype(A, TAB), convert(AbstractArray{TAB}, B))
13451377
end
13461378
end
13471379
end
13481380
### Right division with triangle to the right hence lhs cannot be transposed. No quotients.
13491381
for (f, g) in ((:/, :A_rdiv_B!), (:A_rdiv_Bc, :A_rdiv_Bc!), (:A_rdiv_Bt, :A_rdiv_Bt!))
13501382
@eval begin
1351-
function ($f){TA,TB,S}(A::StridedVecOrMat{TA}, B::Union{UnitUpperTriangular{TB,S},UnitLowerTriangular{TB,S}})
1352-
TAB = typeof(zero(TA)*zero(TB) + zero(TA)*zero(TB))
1383+
function ($f)(A::$mat, B::Union{UnitUpperTriangular,UnitLowerTriangular})
1384+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
13531385
($g)(copy_oftype(A, TAB), convert(AbstractArray{TAB}, B))
13541386
end
13551387
end
13561388
end
1389+
13571390
### Right division with triangle to the right hence lhs cannot be transposed. Quotients.
13581391
for (f, g) in ((:/, :A_rdiv_B!), (:A_rdiv_Bc, :A_rdiv_Bc!), (:A_rdiv_Bt, :A_rdiv_Bt!))
13591392
@eval begin
1360-
function ($f){TA,TB,S}(A::StridedVecOrMat{TA}, B::Union{UpperTriangular{TB,S},LowerTriangular{TB,S}})
1361-
TAB = typeof((zero(TA)*zero(TB) + zero(TA)*zero(TB))/one(TA))
1393+
function ($f)(A::$mat, B::Union{UpperTriangular,LowerTriangular})
1394+
TAB = typeof((zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))/one(eltype(A)))
13621395
($g)(copy_oftype(A, TAB), convert(AbstractArray{TAB}, B))
13631396
end
13641397
end
13651398
end
1399+
end
13661400
### Fallbacks brought in from linalg/bidiag.jl while fixing #14506.
13671401
# Eventually the above promotion methods should be generalized as
13681402
# was done for bidiagonal matrices in #14506.

base/sparse/sparsevector.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1525,7 +1525,7 @@ for isunittri in (true, false), islowertri in (true, false)
15251525
(true, :(Ac_ldiv_B), :(Ac_ldiv_B!)) )
15261526

15271527
# broad method where elements are Numbers
1528-
@eval function ($func){TA<:Number,Tb<:Number,S}(A::$tritype{TA,S}, b::SparseVector{Tb})
1528+
@eval function ($func){TA<:Number,Tb<:Number,S<:AbstractMatrix}(A::$tritype{TA,S}, b::SparseVector{Tb})
15291529
TAb = $(isunittri ?
15301530
:(typeof(zero(TA)*zero(Tb) + zero(TA)*zero(Tb))) :
15311531
:(typeof((zero(TA)*zero(Tb) + zero(TA)*zero(Tb))/one(TA))) )
@@ -1551,7 +1551,7 @@ for isunittri in (true, false), islowertri in (true, false)
15511551
end
15521552

15531553
# fallback where elements are not Numbers
1554-
@eval ($func){TA,Tb,S}(A::$tritype{TA,S}, b::SparseVector{Tb}) = ($ipfunc)(A, copy(b))
1554+
@eval ($func)(A::$tritype, b::SparseVector) = ($ipfunc)(A, copy(b))
15551555
end
15561556

15571557
# build in-place left-division operations

test/linalg/triangular.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,3 +485,6 @@ let
485485
@test_throws DimensionMismatch A_rdiv_Bt!(A, UnitLowerTriangular(B))
486486
@test_throws DimensionMismatch A_rdiv_Bt!(A, UnitUpperTriangular(B))
487487
end
488+
489+
# Issue 16196
490+
@test UpperTriangular(eye(3)) \ sub(ones(3), [1,2,3]) == ones(3)

0 commit comments

Comments
 (0)