Skip to content

Commit 49c970f

Browse files
committed
Add methods for left-division operations involving triangular matrices and SparseVectors. Also add tests for those methods.
1 parent 493f913 commit 49c970f

File tree

3 files changed

+186
-1
lines changed

3 files changed

+186
-1
lines changed

base/sparse.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ using Base.Sort: Forward
77
using Base.LinAlg: AbstractTriangular, PosDefException
88

99
import Base: +, -, *, \, &, |, $, .+, .-, .*, ./, .\, .^, .<, .!=, ==
10-
import Base: A_mul_B!, Ac_mul_B, Ac_mul_B!, At_mul_B, At_mul_B!, A_ldiv_B!
10+
import Base: A_mul_B!, Ac_mul_B, Ac_mul_B!, At_mul_B, At_mul_B!, At_ldiv_B, Ac_ldiv_B, A_ldiv_B!
11+
import Base.LinAlg: At_ldiv_B!, Ac_ldiv_B!
1112

1213
import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
1314
atan, atand, atanh, broadcast!, chol, conj!, cos, cosc, cosd, cosh, cospi, cot,

base/sparse/sparsevector.jl

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1492,3 +1492,127 @@ function _At_or_Ac_mul_B{TvA,TiA,TvX,TiX}(tfun::BinaryOp, A::SparseMatrixCSC{TvA
14921492
end
14931493
SparseVector(n, ynzind, ynzval)
14941494
end
1495+
1496+
# define matrix division operations involving triangular matrices and sparse vectors
1497+
# the valid left-division operations are A[t|c]_ldiv_B[!] and \
1498+
# the valid right-division operations are A(t|c)_rdiv_B[t|c][!]
1499+
# see issue #14005 for discussion of these methods
1500+
for isunittri in (true, false), islowertri in (true, false)
1501+
unitstr = isunittri ? "Unit" : ""
1502+
halfstr = islowertri ? "Lower" : "Upper"
1503+
tritype = :(Base.LinAlg.$(symbol(string(unitstr, halfstr, "Triangular"))))
1504+
1505+
# build out-of-place left-division operations
1506+
for (istrans, func, ipfunc) in (
1507+
(false, :(\), :(A_ldiv_B!)),
1508+
(true, :(At_ldiv_B), :(At_ldiv_B!)),
1509+
(true, :(Ac_ldiv_B), :(Ac_ldiv_B!)) )
1510+
1511+
# broad method where elements are Numbers
1512+
@eval function ($func){TA<:Number,Tb<:Number,S}(A::$tritype{TA,S}, b::SparseVector{Tb})
1513+
TAb = $(isunittri ?
1514+
:(typeof(zero(TA)*zero(Tb) + zero(TA)*zero(Tb))) :
1515+
:(typeof((zero(TA)*zero(Tb) + zero(TA)*zero(Tb))/one(TA))) )
1516+
($ipfunc)(convert(AbstractArray{TAb}, A), convert(Array{TAb}, b))
1517+
end
1518+
1519+
# faster method requiring good view support of the
1520+
# triangular matrix type. hence the StridedMatrix restriction.
1521+
@eval function ($func){TA<:Number,Tb<:Number,S<:StridedMatrix}(A::$tritype{TA,S}, b::SparseVector{Tb})
1522+
TAb = $(isunittri ?
1523+
:(typeof(zero(TA)*zero(Tb) + zero(TA)*zero(Tb))) :
1524+
:(typeof((zero(TA)*zero(Tb) + zero(TA)*zero(Tb))/one(TA))) )
1525+
r = convert(Array{TAb}, b)
1526+
# this operation involves only b[nzrange], so we extract
1527+
# and operate on solely that section for efficiency
1528+
nzrange = $( (islowertri && !istrans) || (!islowertri && istrans) ?
1529+
:(b.nzind[1]:b.n) :
1530+
:(1:b.nzind[end]) )
1531+
nzrangeviewr = sub(r, nzrange)
1532+
nzrangeviewA = $tritype(sub(A.data, nzrange, nzrange))
1533+
($ipfunc)(convert(AbstractArray{TAb}, nzrangeviewA), nzrangeviewr)
1534+
r
1535+
end
1536+
1537+
# fallback where elements are not Numbers
1538+
@eval ($func){TA,Tb,S}(A::$tritype{TA,S}, b::SparseVector{Tb}) = ($ipfunc)(A, copy(b))
1539+
end
1540+
1541+
# build in-place left-division operations
1542+
for (istrans, func) in (
1543+
(false, :(A_ldiv_B!)),
1544+
(true, :(At_ldiv_B!)),
1545+
(true, :(Ac_ldiv_B!)) )
1546+
1547+
# the generic in-place left-division methods handle these cases, but
1548+
# we can achieve greater efficiency where the triangular matrix provides
1549+
# good view support. hence the StridedMatrix restriction.
1550+
@eval function ($func){TA,Tb,S<:StridedMatrix}(A::$tritype{TA,S}, b::SparseVector{Tb})
1551+
# densify the relevant part of b in one shot rather
1552+
# than potentially repeatedly reallocating during the solve
1553+
$( (islowertri && !istrans) || (!islowertri && istrans) ?
1554+
:(_densifyfirstnztoend!(b)) :
1555+
:(_densifystarttolastnz!(b)) )
1556+
# this operation involves only the densified section, so
1557+
# for efficiency we extract and operate on solely that section
1558+
# furthermore we operate on that section as a dense vector
1559+
# such that dispatch has a chance to exploit, e.g., tuned BLAS
1560+
nzrange = $( (islowertri && !istrans) || (!islowertri && istrans) ?
1561+
:(b.nzind[1]:b.n) :
1562+
:(1:b.nzind[end]) )
1563+
nzrangeviewbnz = sub(b.nzval, nzrange - b.nzind[1] + 1)
1564+
nzrangeviewA = $tritype(sub(A.data, nzrange, nzrange))
1565+
($func)(nzrangeviewA, nzrangeviewbnz)
1566+
# could strip any miraculous zeros here perhaps
1567+
b
1568+
end
1569+
end
1570+
end
1571+
1572+
# helper functions for in-place matrix division operations defined above
1573+
"Densifies `x::SparseVector` from its first nonzero (`x[x.nzind[1]]`) through its end (`x[x.n]`)."
1574+
function _densifyfirstnztoend!(x::SparseVector)
1575+
# lengthen containers
1576+
oldnnz = nnz(x)
1577+
newnnz = x.n - x.nzind[1] + 1
1578+
resize!(x.nzval, newnnz)
1579+
resize!(x.nzind, newnnz)
1580+
# redistribute nonzero values over lengthened container
1581+
# initialize now-allocated zero values simultaneously
1582+
nextpos = newnnz
1583+
@inbounds for oldpos in oldnnz:-1:1
1584+
nzi = x.nzind[oldpos]
1585+
nzv = x.nzval[oldpos]
1586+
newpos = nzi - x.nzind[1] + 1
1587+
newpos < nextpos && (x.nzval[newpos+1:nextpos] = 0)
1588+
newpos == oldpos && break
1589+
x.nzval[newpos] = nzv
1590+
nextpos = newpos - 1
1591+
end
1592+
# finally update lengthened nzinds
1593+
x.nzind[2:end] = (x.nzind[1]+1):x.n
1594+
x
1595+
end
1596+
"Densifies `x::SparseVector` from its beginning (`x[1]`) through its last nonzero (`x[x.nzind[end]]`)."
1597+
function _densifystarttolastnz!(x::SparseVector)
1598+
# lengthen containers
1599+
oldnnz = nnz(x)
1600+
newnnz = x.nzind[end]
1601+
resize!(x.nzval, newnnz)
1602+
resize!(x.nzind, newnnz)
1603+
# redistribute nonzero values over lengthened container
1604+
# initialize now-allocated zero values simultaneously
1605+
nextpos = newnnz
1606+
@inbounds for oldpos in oldnnz:-1:1
1607+
nzi = x.nzind[oldpos]
1608+
nzv = x.nzval[oldpos]
1609+
nzi < nextpos && (x.nzval[nzi+1:nextpos] = 0)
1610+
nzi == oldpos && (nextpos = 0; break)
1611+
x.nzval[nzi] = nzv
1612+
nextpos = nzi - 1
1613+
end
1614+
nextpos > 0 && (x.nzval[1:nextpos] = 0)
1615+
# finally update lengthened nzinds
1616+
x.nzind[1:newnnz] = 1:newnnz
1617+
x
1618+
end

test/sparsedir/sparsevector.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,66 @@ let A = complex(sprandn(7, 8, 0.5), sprandn(7, 8, 0.5)),
724724
@test_approx_eq full(y) Af'x2f
725725
end
726726

727+
# left-division operations involving triangular matrices and sparse vectors (#14005)
728+
let m = 10
729+
sparsefloatvecs = SparseVector[sprand(m, 0.4) for k in 1:3]
730+
sparseintvecs = SparseVector[SparseVector(m, sprvec.nzind, round(Int, sprvec.nzval*10)) for sprvec in sparsefloatvecs]
731+
sparsecomplexvecs = SparseVector[SparseVector(m, sprvec.nzind, complex(sprvec.nzval, sprvec.nzval)) for sprvec in sparsefloatvecs]
732+
733+
sprmat = sprand(m, m, 0.2)
734+
sparsefloatmat = speye(m) + sprmat/(2m)
735+
sparsecomplexmat = speye(m) + SparseMatrixCSC(m, m, sprmat.colptr, sprmat.rowval, complex(sprmat.nzval, sprmat.nzval)/(4m))
736+
sparseintmat = speye(Int, m)*10m + SparseMatrixCSC(m, m, sprmat.colptr, sprmat.rowval, round(Int, sprmat.nzval*10))
737+
738+
denseintmat = eye(Int, m)*10m + rand(1:m, m, m)
739+
densefloatmat = eye(m) + randn(m, m)/(2m)
740+
densecomplexmat = eye(m) + complex(randn(m, m), randn(m, m))/(4m)
741+
742+
inttypes = (Int32, Int64, BigInt)
743+
floattypes = (Float32, Float64, BigFloat)
744+
complextypes = (Complex{Float32}, Complex{Float64})
745+
eltypes = (inttypes..., floattypes..., complextypes...)
746+
747+
for eltypemat in eltypes
748+
(densemat, sparsemat) = eltypemat in inttypes ? (denseintmat, sparseintmat) :
749+
eltypemat in floattypes ? (densefloatmat, sparsefloatmat) :
750+
eltypemat in complextypes && (densecomplexmat, sparsecomplexmat)
751+
densemat = convert(Matrix{eltypemat}, densemat)
752+
sparsemat = convert(SparseMatrixCSC{eltypemat}, sparsemat)
753+
trimats = (LowerTriangular(densemat), UpperTriangular(densemat),
754+
LowerTriangular(sparsemat), UpperTriangular(sparsemat) )
755+
unittrimats = (Base.LinAlg.UnitLowerTriangular(densemat), Base.LinAlg.UnitUpperTriangular(densemat),
756+
Base.LinAlg.UnitLowerTriangular(sparsemat), Base.LinAlg.UnitUpperTriangular(sparsemat) )
757+
758+
for eltypevec in eltypes
759+
spvecs = eltypevec in inttypes ? sparseintvecs :
760+
eltypevec in floattypes ? sparsefloatvecs :
761+
eltypevec in complextypes && sparsecomplexvecs
762+
spvecs = SparseVector[SparseVector(m, spvec.nzind, convert(Vector{eltypevec}, spvec.nzval)) for spvec in spvecs]
763+
764+
for spvec in spvecs
765+
fspvec = convert(Array, spvec)
766+
# test out-of-place left-division methods
767+
for mat in (trimats..., unittrimats...), func in (\, At_ldiv_B, Ac_ldiv_B)
768+
@test isapprox((func)(mat, spvec), (func)(mat, fspvec))
769+
end
770+
# test in-place left-division methods not involving quotients
771+
if eltypevec == typeof(zero(eltypemat)*zero(eltypevec) + zero(eltypemat)*zero(eltypevec))
772+
for mat in unittrimats, func in (A_ldiv_B!, Base.LinAlg.At_ldiv_B!, Base.LinAlg.Ac_ldiv_B!)
773+
@test isapprox((func)(mat, copy(spvec)), (func)(mat, copy(fspvec)))
774+
end
775+
end
776+
# test in-place left-division methods involving quotients
777+
if eltypevec == typeof((zero(eltypemat)*zero(eltypevec) + zero(eltypemat)*zero(eltypevec))/one(eltypemat))
778+
for mat in trimats, func in (A_ldiv_B!, Base.LinAlg.At_ldiv_B!, Base.LinAlg.Ac_ldiv_B!)
779+
@test isapprox((func)(mat, copy(spvec)), (func)(mat, copy(fspvec)))
780+
end
781+
end
782+
end
783+
end
784+
end
785+
end
786+
727787
# It's tempting to share data between a SparseVector and a SparseArrays,
728788
# but if that's done, then modifications to one or the other will cause
729789
# an inconsistent state:

0 commit comments

Comments
 (0)