Skip to content

Commit 5054941

Browse files
committed
Further improvements
1 parent 5759ceb commit 5054941

File tree

4 files changed

+150
-52
lines changed

4 files changed

+150
-52
lines changed

base/linalg/bidiag.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ function A_mul_B_td!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym)
360360
C
361361
end
362362

363-
SpecialMatrix = Union{Bidiagonal, SymTridiagonal, Tridiagonal, AbstractTriangular}
363+
SpecialMatrix = Union{Bidiagonal, SymTridiagonal, Tridiagonal}
364364
# to avoid ambiguity warning, but shouldn't be necessary
365365
*(A::AbstractTriangular, B::SpecialMatrix) = full(A) * full(B)
366366
*(A::SpecialMatrix, B::SpecialMatrix) = full(A) * full(B)

base/linalg/diagonal.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,6 @@ for uplo in (:LowerTriangular, :UpperTriangular)
118118
end
119119
end
120120
end
121-
# (*)(A::AbstractTriangular, D::Diagonal) =
122-
# error("this method should never get called. Please make a bug report.")
123121
*(A::AbstractMatrix, D::Diagonal) =
124122
scale!(similar(A, promote_op(*, eltype(A), eltype(D.diag))), A, D.diag)
125123
*(D::Diagonal, A::AbstractMatrix) =

base/linalg/triangular.jl

Lines changed: 148 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,25 +1291,48 @@ function A_rdiv_Bt!(A::StridedMatrix, B::UnitLowerTriangular)
12911291
A
12921292
end
12931293

1294-
for f in (:A_rdiv_B!, :A_rdiv_Bc!, :A_rdiv_Bt!)
1295-
for (uplo, fuplo) in ((:Lower, :tril!), (:Upper, :triu!))
1296-
mat = Symbol(uplo, :Triangular)
1297-
umat = Symbol(:Unit, mat)
1298-
@eval begin
1299-
$f(A::$mat, B::Union{$mat,$umat}) = ($mat)($f($fuplo(A.data), B))
1300-
end
1294+
for f in (:A_mul_B!, :A_ldiv_B!)
1295+
@eval begin
1296+
$f(A::UpperTriangular, B::UpperTriangular) =
1297+
UpperTriangular($f(A, triu!(B.data)))
1298+
$f(A::UnitUpperTriangular, B::UpperTriangular) =
1299+
UpperTriangular($f(A, triu!(B.data)))
1300+
$f(A::UpperTriangular, B::UnitUpperTriangular) =
1301+
UpperTriangular($f(A, triu!(B.data)))
1302+
$f(A::UnitUpperTriangular, B::UnitUpperTriangular) =
1303+
UnitUpperTriangular($f(A, triu!(B.data)))
1304+
$f(A::LowerTriangular, B::LowerTriangular) =
1305+
LowerTriangular($f(A, tril!(B.data)))
1306+
$f(A::UnitLowerTriangular, B::LowerTriangular) =
1307+
LowerTriangular($f(A, tril!(B.data)))
1308+
$f(A::LowerTriangular, B::UnitLowerTriangular) =
1309+
LowerTriangular($f(A, tril!(B.data)))
1310+
$f(A::UnitLowerTriangular, B::UnitLowerTriangular) =
1311+
LowerTriangular($f(A, tril!(B.data)))
1312+
end
1313+
end
1314+
1315+
for f in (:Ac_mul_B!, :At_mul_B!, :Ac_ldiv_B!, :At_ldiv_B!)
1316+
@eval begin
1317+
$f(A::Union{LowerTriangular,UnitLowerTriangular}, B::UpperTriangular) =
1318+
UpperTriangular($f(A, triu!(B.data)))
1319+
$f(A::Union{UpperTriangular,UnitUpperTriangular}, B::LowerTriangular) =
1320+
LowerTriangular($f(A, tril!(B.data)))
13011321
end
1302-
@eval $f(A::AbstractTriangular, B::AbstractTriangular) = $f(full!(A), B)
13031322
end
1304-
for f in (:A_mul_B!, :Ac_mul_B!, :At_mul_B!, :A_ldiv_B!, :Ac_ldiv_B!, :At_ldiv_B!)
1305-
for (uplo, fuplo) in ((:Lower, :tril!), (:Upper, :triu!))
1306-
mat = Symbol(uplo, :Triangular)
1307-
umat = Symbol(:Unit, mat)
1308-
@eval begin
1309-
$f(A::Union{$mat,$umat}, B::$mat) = ($mat)($f(A, $fuplo(B.data)))
1310-
end
1323+
1324+
A_rdiv_B!(A::UpperTriangular, B::Union{UpperTriangular,UnitUpperTriangular}) =
1325+
UpperTriangular(A_rdiv_B!(triu!(A.data), B))
1326+
A_rdiv_B!(A::LowerTriangular, B::Union{LowerTriangular,UnitLowerTriangular}) =
1327+
LowerTriangular(A_rdiv_B!(tril!(A.data), B))
1328+
1329+
for f in (:A_mul_Bc!, :A_mul_Bt!, :A_rdiv_Bc!, :A_rdiv_Bt!)
1330+
@eval begin
1331+
$f(A::UpperTriangular, B::Union{LowerTriangular,UnitLowerTriangular}) =
1332+
UpperTriangular($f(triu!(A.data), B))
1333+
$f(A::LowerTriangular, B::Union{UpperTriangular,UnitUpperTriangular}) =
1334+
LowerTriangular($f(tril!(A.data), B))
13111335
end
1312-
@eval $f(A::AbstractTriangular, B::AbstractTriangular) = $f(A, full!(B))
13131336
end
13141337

13151338
# Promotion
@@ -1318,45 +1341,104 @@ end
13181341
## Some Triangular-Triangular cases. We might want to write taylored methods for these cases, but I'm not sure it is worth it.
13191342
for t in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriangular)
13201343
@eval begin
1321-
*(A::Tridiagonal, B::$t) = A_mul_B!(full(A), B)
1344+
(*)(A::Tridiagonal, B::$t) = A_mul_B!(full(A), B)
13221345
end
13231346
end
13241347

1325-
for f in (:A_mul_Bc, :A_mul_Bt)
1326-
for (uplo, fuplo) in ((:Lower, :tril!), (:Upper, :triu!))
1327-
mat = Symbol(uplo, :Triangular)
1328-
umat = Symbol(:Unit, mat)
1329-
@eval begin
1330-
function $f(A::$mat, B::Union{$mat,$umat})
1331-
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
1332-
return ($mat)($f(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B)))
1333-
end
1348+
for (f1, f2) in ((:*, :A_mul_B!), (:\, :A_ldiv_B))
1349+
@eval begin
1350+
function $f1(A::Union{LowerTriangular, UnitLowerTriangular}, B::LowerTriangular)
1351+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
1352+
return LowerTriangular($f2(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB)))
1353+
end
1354+
1355+
function $f1(A::Union{UpperTriangular, UnitUpperTriangular}, B::UpperTriangular)
1356+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
1357+
return UpperTriangular($f2(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB)))
13341358
end
13351359
end
1336-
@eval $f(A::AbstractTriangular, B::AbstractTriangular) = $f(full(A), B)
13371360
end
1338-
for f in (:*, :Ac_mul_B, :At_mul_B)
1339-
for (uplo, fuplo) in ((:Lower, :tril!), (:Upper, :triu!))
1340-
mat = Symbol(uplo, :Triangular)
1341-
umat = Symbol(:Unit, mat)
1342-
@eval begin
1343-
function $f(A::Union{$mat,$umat}, B::$mat)
1344-
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
1345-
return ($mat)($f(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB)))
1346-
end
1361+
1362+
for (f1, f2) in ((:Ac_mul_B, :Ac_mul_B!), (:At_mul_B, :At_mul_B!),
1363+
(:Ac_ldiv_B, Ac_ldiv_B!), (:At_ldiv_B, :At_ldiv_B!))
1364+
@eval begin
1365+
function $f1(A::Union{UpperTriangular, UnitUpperTriangular}, B::LowerTriangular)
1366+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
1367+
return LowerTriangular($f2(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB)))
1368+
end
1369+
1370+
function $f1(A::Union{LowerTriangular, UnitLowerTriangular}, B::UpperTriangular)
1371+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
1372+
return UpperTriangular($f2(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB)))
1373+
end
1374+
end
1375+
end
1376+
1377+
function (/)(A::LowerTriangular, B::LowerTriangular)
1378+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))/
1379+
one(eltype(A)))
1380+
return LowerTriangular(A_rdiv_B!(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B)))
1381+
end
1382+
function (/)(A::LowerTriangular, B::UnitLowerTriangular)
1383+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
1384+
return LowerTriangular(A_rdiv_B!(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B)))
1385+
end
1386+
function (/)(A::UpperTriangular, B::UpperTriangular)
1387+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))/
1388+
one(eltype(A)))
1389+
return UpperTriangular(A_rdiv_B!(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B)))
1390+
end
1391+
function (/)(A::UpperTriangular, B::UnitUpperTriangular)
1392+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
1393+
return UpperTriangular(A_rdiv_B!(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B)))
1394+
end
1395+
1396+
for (f1, f2) in ((:A_mul_Bc, :A_mul_Bc!), (:A_mul_Bt, :A_mul_Bt!),
1397+
(:A_rdiv_Bc, :A_rdiv_Bc!), (:A_rdiv_Bt, :A_rdiv_Bt!))
1398+
@eval begin
1399+
function $f1(A::LowerTriangular, B::Union{UpperTriangular, UnitUpperTriangular})
1400+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
1401+
return LowerTriangular($f2(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B)))
1402+
end
1403+
1404+
function $f1(A::UpperTriangular, B::LowerTriangular)
1405+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
1406+
return UpperTriangular($f2(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B)))
1407+
end
1408+
1409+
function $f1(A::UpperTriangular, B::UnitLowerTriangular)
1410+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
1411+
return UpperTriangular($f2(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B)))
13471412
end
13481413
end
1349-
@eval $f(A::AbstractTriangular, B::AbstractTriangular) = $f(A, full(B))
13501414
end
13511415

1416+
# A_mul_Bc(A::AbstractTriangular, B::AbstractTriangular) = A_mul_Bc!(full(A), B)
1417+
# A_mul_Bt(A::AbstractTriangular, B::AbstractTriangular) = A_mul_Bt!(full(A), B)
1418+
13521419
## The general promotion methods
1420+
1421+
for (f, g) in ((:*, :A_mul_B!), (:Ac_mul_B, :Ac_mul_B!), (:At_mul_B, :At_mul_B!))
1422+
@eval begin
1423+
function ($f)(A::AbstractTriangular, B::AbstractTriangular)
1424+
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
1425+
BB = similar(B, TAB, size(B))
1426+
copy!(BB, B)
1427+
($g)(convert(AbstractArray{TAB}, A), BB)
1428+
end
1429+
end
1430+
end
1431+
13531432
for mat in (:AbstractVector, AbstractMatrix)
1433+
13541434
### Multiplication with triangle to the left and hence rhs cannot be transposed.
13551435
for (f, g) in ((:*, :A_mul_B!), (:Ac_mul_B, :Ac_mul_B!), (:At_mul_B, :At_mul_B!))
13561436
@eval begin
13571437
function ($f)(A::AbstractTriangular, B::$mat)
13581438
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
1359-
($g)(convert(AbstractArray{TAB}, A), copy_oftype(B, TAB))
1439+
BB = similar(B, TAB, size(B))
1440+
copy!(BB, B)
1441+
($g)(convert(AbstractArray{TAB}, A), BB)
13601442
end
13611443
end
13621444
end
@@ -1365,7 +1447,9 @@ for (f, g) in ((:\, :A_ldiv_B!), (:Ac_ldiv_B, :Ac_ldiv_B!), (:At_ldiv_B, :At_ldi
13651447
@eval begin
13661448
function ($f)(A::Union{UnitUpperTriangular,UnitLowerTriangular}, B::$mat)
13671449
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
1368-
($g)(convert(AbstractArray{TAB}, A), copy_oftype(B, TAB))
1450+
BB = similar(B, TAB, size(B))
1451+
copy!(BB, B)
1452+
($g)(convert(AbstractArray{TAB}, A), BB)
13691453
end
13701454
end
13711455
end
@@ -1374,7 +1458,9 @@ for (f, g) in ((:\, :A_ldiv_B!), (:Ac_ldiv_B, :Ac_ldiv_B!), (:At_ldiv_B, :At_ldi
13741458
@eval begin
13751459
function ($f)(A::Union{UpperTriangular,LowerTriangular}, B::$mat)
13761460
TAB = typeof((zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))/one(eltype(A)))
1377-
($g)(convert(AbstractArray{TAB}, A), copy_oftype(B, TAB))
1461+
BB = similar(B, TAB, size(B))
1462+
copy!(BB, B)
1463+
($g)(convert(AbstractArray{TAB}, A), BB)
13781464
end
13791465
end
13801466
end
@@ -1383,16 +1469,20 @@ for (f, g) in ((:*, :A_mul_B!), (:A_mul_Bc, :A_mul_Bc!), (:A_mul_Bt, :A_mul_Bt!)
13831469
@eval begin
13841470
function ($f)(A::$mat, B::AbstractTriangular)
13851471
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
1386-
($g)(copy_oftype(A, TAB), convert(AbstractArray{TAB}, B))
1472+
AA = similar(A, TAB, size(A))
1473+
copy!(AA, A)
1474+
($g)(AA, convert(AbstractArray{TAB}, B))
13871475
end
13881476
end
13891477
end
13901478
### Right division with triangle to the right hence lhs cannot be transposed. No quotients.
13911479
for (f, g) in ((:/, :A_rdiv_B!), (:A_rdiv_Bc, :A_rdiv_Bc!), (:A_rdiv_Bt, :A_rdiv_Bt!))
13921480
@eval begin
1393-
function ($f)(A::$mat, B::Union{UnitUpperTriangular,UnitLowerTriangular})
1481+
function ($f)(A::$mat, B::Tuple{UnitUpperTriangular, UnitLowerTriangular})
13941482
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
1395-
($g)(copy_oftype(A, TAB), convert(AbstractArray{TAB}, B))
1483+
AA = similar(A, TAB, size(A))
1484+
copy!(AA, A)
1485+
($g)(AA, convert(AbstractArray{TAB}, B))
13961486
end
13971487
end
13981488
end
@@ -1402,16 +1492,25 @@ for (f, g) in ((:/, :A_rdiv_B!), (:A_rdiv_Bc, :A_rdiv_Bc!), (:A_rdiv_Bt, :A_rdiv
14021492
@eval begin
14031493
function ($f)(A::$mat, B::Union{UpperTriangular,LowerTriangular})
14041494
TAB = typeof((zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))/one(eltype(A)))
1405-
($g)(copy_oftype(A, TAB), convert(AbstractArray{TAB}, B))
1495+
AA = similar(A, TAB, size(A))
1496+
copy!(AA, A)
1497+
($g)(AA, convert(AbstractArray{TAB}, B))
14061498
end
14071499
end
14081500
end
14091501
end
1410-
### Fallbacks brought in from linalg/bidiag.jl while fixing #14506.
1411-
# Eventually the above promotion methods should be generalized as
1412-
# was done for bidiagonal matrices in #14506.
1413-
At_ldiv_B(A::AbstractTriangular, B::AbstractVecOrMat) = At_ldiv_B!(A, copy(B))
1414-
Ac_ldiv_B(A::AbstractTriangular, B::AbstractVecOrMat) = Ac_ldiv_B!(A, copy(B))
1502+
1503+
# If these are not defined, the they will fallback to the versions in matmul.jl
1504+
# and dispatch to generic_matmatmul! which is very costly to compile. The methods
1505+
# below might compute an unnecessary copy. Eliminating the copy requires adding
1506+
# all the promotion logic here once again. Since these methods are probably relatively
1507+
# rare, we chose not to bother for now.
1508+
Ac_mul_Bc(A::AbstractTriangular, B::AbstractTriangular) = Ac_mul_B(A, B')
1509+
Ac_mul_Bc(A::AbstractTriangular, B::AbstractMatrix) = Ac_mul_B(A, B')
1510+
Ac_mul_Bc(A::AbstractMatrix, B::AbstractTriangular) = A_mul_Bc(A', B)
1511+
At_mul_Bt(A::AbstractTriangular, B::AbstractTriangular) = At_mul_B(A, B.')
1512+
At_mul_Bt(A::AbstractTriangular, B::AbstractMatrix) = At_mul_B(A, B.')
1513+
At_mul_Bt(A::AbstractMatrix, B::AbstractTriangular) = A_mul_Bt(A.', B)
14151514

14161515
# Complex matrix logarithm for the upper triangular factor, see:
14171516
# Al-Mohy and Higham, "Improved inverse scaling and squaring algorithms for

test/linalg/triangular.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ for elty1 in (Float32, Float64, BigFloat, Complex64, Complex128, Complex{BigFloa
2424
# Construct test matrix
2525
A1 = t1(elty1 == Int ? rand(1:7, n, n) : convert(Matrix{elty1}, (elty1 <: Complex ? complex(randn(n, n), randn(n, n)) : randn(n, n)) |> t -> chol(t't) |> t -> uplo1 == :U ? t : ctranspose(t)))
2626

27+
2728
debug && println("elty1: $elty1, A1: $t1")
2829

2930
# Convert

0 commit comments

Comments
 (0)