Skip to content

Commit eeebda6

Browse files
committed
Add methods for left-division operations involving triangular matrices and SparseVectors. Also add tests for those methods.
1 parent e83b755 commit eeebda6

File tree

3 files changed

+186
-1
lines changed

3 files changed

+186
-1
lines changed

base/sparse.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ using Base.Sort: Forward
77
using Base.LinAlg: AbstractTriangular, PosDefException
88

99
import Base: +, -, *, \, &, |, $, .+, .-, .*, ./, .\, .^, .<, .!=, ==
10-
import Base: A_mul_B!, Ac_mul_B, Ac_mul_B!, At_mul_B, At_mul_B!, A_ldiv_B!
10+
import Base: A_mul_B!, Ac_mul_B, Ac_mul_B!, At_mul_B, At_mul_B!, At_ldiv_B, Ac_ldiv_B, A_ldiv_B!
11+
import Base.LinAlg: At_ldiv_B!, Ac_ldiv_B!
1112

1213
import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
1314
atan, atand, atanh, broadcast!, chol, conj!, cos, cosc, cosd, cosh, cospi, cot,

base/sparse/sparsevector.jl

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1492,3 +1492,124 @@ function _At_or_Ac_mul_B{TvA,TiA,TvX,TiX}(tfun::BinaryOp, A::SparseMatrixCSC{TvA
14921492
end
14931493
SparseVector(n, ynzind, ynzval)
14941494
end
1495+
1496+
# define matrix division operations involving triangular matrices and sparse vectors
1497+
# the valid left-division operations are A[t|c]_ldiv_B[!] and \
1498+
# the valid right-division operations are A(t|c)_rdiv_B[t|c][!]
1499+
# see issue #14005 for discussion of these methods
1500+
for isunittri in (true, false), islowertri in (true, false)
1501+
unitstr = isunittri ? "Unit" : ""
1502+
halfstr = islowertri ? "Lower" : "Upper"
1503+
tritype = :(Base.LinAlg.$(symbol(string(unitstr, halfstr, "Triangular"))))
1504+
1505+
# build out-of-place left-division operations
1506+
for (istrans, func, ipfunc) in (
1507+
(false, :(\), :(A_ldiv_B!)),
1508+
(true, :(At_ldiv_B), :(At_ldiv_B!)),
1509+
(true, :(Ac_ldiv_B), :(Ac_ldiv_B!)) )
1510+
1511+
# broad method
1512+
@eval function ($func){TA,Tb,S}(A::$tritype{TA,S}, b::SparseVector{Tb})
1513+
TAb = $(isunittri ?
1514+
:(typeof(zero(TA)*zero(Tb) + zero(TA)*zero(Tb))) :
1515+
:(typeof((zero(TA)*zero(Tb) + zero(TA)*zero(Tb))/one(TA))) )
1516+
($ipfunc)(convert(AbstractArray{TAb}, A), convert(Array{TAb}, b))
1517+
end
1518+
1519+
# faster method requiring good view support of the
1520+
# triangular matrix type. hence the StridedMatrix restriction.
1521+
@eval function ($func){TA,Tb,S<:StridedMatrix}(A::$tritype{TA,S}, b::SparseVector{Tb})
1522+
TAb = $(isunittri ?
1523+
:(typeof(zero(TA)*zero(Tb) + zero(TA)*zero(Tb))) :
1524+
:(typeof((zero(TA)*zero(Tb) + zero(TA)*zero(Tb))/one(TA))) )
1525+
r = convert(Array{TAb}, b)
1526+
# this operation involves only b[nzrange], so we extract
1527+
# and operate on solely that section for efficiency
1528+
nzrange = $( (islowertri && !istrans) || (!islowertri && istrans) ?
1529+
:(b.nzind[1]:b.n) :
1530+
:(1:b.nzind[end]) )
1531+
nzrangeviewr = sub(r, nzrange)
1532+
nzrangeviewA = $tritype(sub(A.data, nzrange, nzrange))
1533+
($ipfunc)(convert(AbstractArray{TAb}, nzrangeviewA), nzrangeviewr)
1534+
r
1535+
end
1536+
end
1537+
1538+
# build in-place left-division operations
1539+
for (istrans, func) in (
1540+
(false, :(A_ldiv_B!)),
1541+
(true, :(At_ldiv_B!)),
1542+
(true, :(Ac_ldiv_B!)) )
1543+
1544+
# the generic in-place left-division methods handle these cases, but
1545+
# we can achieve greater efficiency where the triangular matrix provides
1546+
# good view support. hence the StridedMatrix restriction.
1547+
@eval function ($func){TA,Tb,S<:StridedMatrix}(A::$tritype{TA,S}, b::SparseVector{Tb})
1548+
# densify the relevant part of b in one shot rather
1549+
# than potentially repeatedly reallocating during the solve
1550+
$( (islowertri && !istrans) || (!islowertri && istrans) ?
1551+
:(_densifyfirstnztoend!(b)) :
1552+
:(_densifystarttolastnz!(b)) )
1553+
# this operation involves only the densified section, so
1554+
# for efficiency we extract and operate on solely that section
1555+
# furthermore we operate on that section as a dense vector
1556+
# such that dispatch has a chance to exploit, e.g., tuned BLAS
1557+
nzrange = $( (islowertri && !istrans) || (!islowertri && istrans) ?
1558+
:(b.nzind[1]:b.n) :
1559+
:(1:b.nzind[end]) )
1560+
nzrangeviewbnz = sub(b.nzval, nzrange - b.nzind[1] + 1)
1561+
nzrangeviewA = $tritype(sub(A.data, nzrange, nzrange))
1562+
($func)(nzrangeviewA, nzrangeviewbnz)
1563+
# could strip any miraculous zeros here perhaps
1564+
b
1565+
end
1566+
end
1567+
end
1568+
1569+
# helper functions for in-place matrix division operations defined above
1570+
"Densifies `x::SparseVector` from its first nonzero (`x[x.nzind[1]]`) through its end (`x[x.n]`)."
1571+
function _densifyfirstnztoend!{Tv,Ti}(x::SparseVector{Tv,Ti})
1572+
# lengthen containers
1573+
oldnnz = nnz(x)
1574+
newnnz = x.n - x.nzind[1] + 1
1575+
resize!(x.nzval, newnnz)
1576+
resize!(x.nzind, newnnz)
1577+
# redistribute nonzero values over lengthened container
1578+
# initialize now-allocated zero values simultaneously
1579+
nextpos = newnnz
1580+
for oldpos in oldnnz:-1:1
1581+
nzi = x.nzind[oldpos]
1582+
nzv = x.nzval[oldpos]
1583+
newpos = nzi - x.nzind[1] + 1
1584+
newpos < nextpos && (x.nzval[newpos+1:nextpos] = 0)
1585+
newpos == oldpos && break
1586+
x.nzval[newpos] = nzv
1587+
nextpos = newpos - 1
1588+
end
1589+
# finally update lengthened nzinds
1590+
x.nzind[2:end] = (x.nzind[1]+1):x.n
1591+
x
1592+
end
1593+
"Densifies `x::SparseVector` from its beginning (`x[1]`) through its last nonzero (`x[x.nzind[end]]`)."
1594+
function _densifystarttolastnz!(x::SparseVector)
1595+
# lengthen containers
1596+
oldnnz = nnz(x)
1597+
newnnz = x.nzind[end]
1598+
resize!(x.nzval, newnnz)
1599+
resize!(x.nzind, newnnz)
1600+
# redistribute nonzero values over lengthened container
1601+
# initialize now-allocated zero values simultaneously
1602+
nextpos = newnnz
1603+
for oldpos in oldnnz:-1:1
1604+
nzi = x.nzind[oldpos]
1605+
nzv = x.nzval[oldpos]
1606+
nzi < nextpos && (x.nzval[nzi+1:nextpos] = 0)
1607+
nzi == oldpos && (nextpos = 0; break)
1608+
x.nzval[nzi] = nzv
1609+
nextpos = nzi - 1
1610+
end
1611+
nextpos > 0 && (x.nzval[1:nextpos] = 0)
1612+
# finally update lengthened nzinds
1613+
x.nzind[1:newnnz] = 1:newnnz
1614+
x
1615+
end

test/sparsedir/sparsevector.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,69 @@ let A = complex(sprandn(7, 8, 0.5), sprandn(7, 8, 0.5)),
724724
@test_approx_eq full(y) Af'x2f
725725
end
726726

727+
# left-division operations involving triangular matrices and sparse vectors (#14005)
728+
let m = 10
729+
sprvec = sprand(m, 0.4)
730+
sparsefloatvec = sprvec
731+
sparseintvec = SparseVector(m, sprvec.nzind, round(Int, sprvec.nzval*10))
732+
sparsecomplexvec = SparseVector(m, sprvec.nzind, complex(sprvec.nzval, sprvec.nzval))
733+
734+
sprmat = sprand(m, m, 0.2)
735+
sparsefloatmat = speye(m) + sprmat/(2m)
736+
sparsecomplexmat = speye(m) + SparseMatrixCSC(m, m, sprmat.colptr, sprmat.rowval, complex(sprmat.nzval, sprmat.nzval)/(4m))
737+
sparseintmat = speye(Int, m)*10m + SparseMatrixCSC(m, m, sprmat.colptr, sprmat.rowval, round(Int, sprmat.nzval*10))
738+
739+
denseintmat = eye(Int, m)*10m + rand(1:m, m, m)
740+
densefloatmat = eye(m) + randn(m, m)/(2m)
741+
densecomplexmat = eye(m) + complex(randn(m, m), randn(m, m))/(4m)
742+
743+
inttypes = (Int32, Int64, BigInt)
744+
floattypes = (Float32, Float64, BigFloat)
745+
complextypes = (Complex64, Complex128)
746+
eltypes = (inttypes..., floattypes..., complextypes...)
747+
748+
for eltypemat in eltypes
749+
eltypemat in inttypes && ((densemat, sparsemat) = (denseintmat, sparseintmat))
750+
eltypemat in floattypes && ((densemat, sparsemat) = (densefloatmat, sparsefloatmat))
751+
eltypemat in complextypes && ((densemat, sparsemat) = (densecomplexmat, sparsecomplexmat))
752+
densemat = convert(Matrix{eltypemat}, densemat)
753+
sparsemat = convert(SparseMatrixCSC{eltypemat}, sparsemat)
754+
trimats = (
755+
LowerTriangular(densemat),
756+
UpperTriangular(densemat),
757+
LowerTriangular(sparsemat),
758+
UpperTriangular(sparsemat) )
759+
unittrimats = (
760+
Base.LinAlg.UnitLowerTriangular(densemat),
761+
Base.LinAlg.UnitUpperTriangular(densemat),
762+
Base.LinAlg.UnitLowerTriangular(sparsemat),
763+
Base.LinAlg.UnitUpperTriangular(sparsemat) )
764+
765+
for eltypevec in eltypes
766+
eltypevec in inttypes && (spvec = sparseintvec)
767+
eltypevec in floattypes && (spvec = sparsefloatvec)
768+
eltypevec in complextypes && (spvec = sparsecomplexvec)
769+
spvec = SparseVector(m, spvec.nzind, convert(Vector{eltypevec}, spvec.nzval))
770+
fspvec = convert(Array, spvec)
771+
772+
# test out-of-place left-division methods
773+
for mat in (trimats..., unittrimats...), func in (\, At_ldiv_B, Ac_ldiv_B)
774+
@test isapprox((func)(mat, spvec), (func)(mat, fspvec))
775+
end
776+
# test in-place left-division methods not involving quotients
777+
eltypevec == typeof(zero(eltypemat)*zero(eltypevec) + zero(eltypemat)*zero(eltypevec)) &&
778+
for mat in unittrimats, func in (A_ldiv_B!, Base.LinAlg.At_ldiv_B!, Base.LinAlg.Ac_ldiv_B!)
779+
@test isapprox((func)(mat, copy(spvec)), (func)(mat, copy(fspvec)))
780+
end
781+
# test in-place left-division methods involving quotients
782+
eltypevec == typeof((zero(eltypemat)*zero(eltypevec) + zero(eltypemat)*zero(eltypevec))/one(eltypemat)) &&
783+
for mat in trimats, func in (A_ldiv_B!, Base.LinAlg.At_ldiv_B!, Base.LinAlg.Ac_ldiv_B!)
784+
@test isapprox((func)(mat, copy(spvec)), (func)(mat, copy(fspvec)))
785+
end
786+
end
787+
end
788+
end #let
789+
727790
# It's tempting to share data between a SparseVector and a SparseArrays,
728791
# but if that's done, then modifications to one or the other will cause
729792
# an inconsistent state:

0 commit comments

Comments
 (0)