Skip to content

Commit 0f341d6

Browse files
committed
Remove explicit dependence of sparse broadcast on type inference
Instead of determining the output element type beforehand by querying inference, the element type is deduced from the actually computed output values (similar to broadcast over Array, but taking into account the output for the all-inputs-zero case). For the type-unstable case, performance is sub-optimal, but at least it gives the correct result. Closes #19595.
1 parent 2d8f5bf commit 0f341d6

File tree

2 files changed

+60
-27
lines changed

2 files changed

+60
-27
lines changed

base/sparse/sparsematrix.jl

Lines changed: 47 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1413,7 +1413,7 @@ function map{Tf,N}(f::Tf, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
14131413
fofzeros = f(_zeros_eltypes(A, Bs...)...)
14141414
fpreszeros = fofzeros == zero(fofzeros)
14151415
maxnnzC = fpreszeros ? min(length(A), _sumnnzs(A, Bs...)) : length(A)
1416-
entrytypeC = _broadcast_type(f, A, Bs...)
1416+
entrytypeC = typeof(fofzeros)
14171417
indextypeC = _promote_indtype(A, Bs...)
14181418
Ccolptr = Vector{indextypeC}(A.n + 1)
14191419
Crowval = Vector{indextypeC}(maxnnzC)
@@ -1438,7 +1438,7 @@ function broadcast{Tf,N}(f::Tf, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N
14381438
fofzeros = f(_zeros_eltypes(A, Bs...)...)
14391439
fpreszeros = fofzeros == zero(fofzeros)
14401440
indextypeC = _promote_indtype(A, Bs...)
1441-
entrytypeC = _broadcast_type(f, A, Bs...)
1441+
entrytypeC = typeof(fofzeros)
14421442
Cm, Cn = Base.to_shape(Base.Broadcast.broadcast_indices(A, Bs...))
14431443
maxnnzC = fpreszeros ? _checked_maxnnzbcres(Cm, Cn, A, Bs...) : (Cm * Cn)
14441444
Ccolptr = Vector{indextypeC}(Cn + 1)
@@ -1464,26 +1464,33 @@ _maxnnzfrom(Cm, Cn, A) = nnz(A) * div(Cm, A.m) * div(Cn, A.n)
14641464
@inline _maxnnzfrom_each(Cm, Cn, As) = (_maxnnzfrom(Cm, Cn, first(As)), _maxnnzfrom_each(Cm, Cn, tail(As))...)
14651465
@inline _unchecked_maxnnzbcres(Cm, Cn, As) = min(Cm * Cn, sum(_maxnnzfrom_each(Cm, Cn, As)))
14661466
@inline _checked_maxnnzbcres(Cm, Cn, As...) = Cm != 0 && Cn != 0 ? _unchecked_maxnnzbcres(Cm, Cn, As) : 0
1467-
_broadcast_type(f, As...) = Base._promote_op(f, Base.Broadcast.typestuple(As...))
1467+
# Store `x` at `nzval[k]` and return the vector that holds the stored value.
# This method applies when `x` is a `T`: the write happens in place and `nzval` itself is returned.
@inline _update_nzval!{T}(nzval::Vector{T}, k, x::T) = (nzval[k] = x; nzval)
1468+
# Fallback when `x` is not a `T`: the values are first converted to a
# `Vector{typejoin(Tx, T)}`, `x` is stored in that converted vector, and the converted vector
# is returned. Callers must therefore continue with the returned vector
# (`nzval = _update_nzval!(nzval, k, x)`); otherwise the stored value is lost.
@inline function _update_nzval!{T,Tx}(nzval::Vector{T}, k, x::Tx)
1469+
nzval = convert(Vector{typejoin(Tx, T)}, nzval)
1470+
nzval[k] = x
1471+
return nzval
1472+
end
14681473

14691474
# _map_zeropres!/_map_notzeropres! specialized for a single sparse matrix
14701475
"Stores only the nonzero entries of `map(f, Matrix(A))` in `C`."
14711476
function _map_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC)
14721477
spaceC = min(length(C.rowval), length(C.nzval))
14731478
Ck = 1
1479+
nzval = C.nzval
14741480
@inbounds for j in 1:C.n
14751481
C.colptr[j] = Ck
14761482
for Ak in nzrange(A, j)
14771483
Cx = f(A.nzval[Ak])
14781484
if Cx != zero(eltype(C))
14791485
Ck > spaceC && (spaceC = _expandstorage!(C, Ck + nnz(A) - (Ak - 1)))
14801486
C.rowval[Ck] = A.rowval[Ak]
1481-
C.nzval[Ck] = Cx
1487+
nzval = _update_nzval!(nzval, Ck, Cx)
14821488
Ck += 1
14831489
end
14841490
end
14851491
end
14861492
@inbounds C.colptr[C.n + 1] = Ck
1493+
nzval === C.nzval || (C = SparseMatrixCSC(C.m, C.n, C.colptr, C.rowval, nzval))
14871494
_trimstorage!(C, Ck - 1)
14881495
return C
14891496
end
@@ -1496,13 +1503,14 @@ function _map_notzeropres!{Tf}(f::Tf, fillvalue, C::SparseMatrixCSC, A::SparseMa
14961503
_densestructure!(C)
14971504
# Populate values
14981505
fill!(C.nzval, fillvalue)
1506+
nzval = C.nzval
14991507
@inbounds for (j, jo) in zip(1:C.n, 0:C.m:(C.m*C.n - 1)), Ak in nzrange(A, j)
15001508
Cx = f(A.nzval[Ak])
1501-
Cx != fillvalue && (C.nzval[jo + A.rowval[Ak]] = Cx)
1509+
Cx != fillvalue && (nzval = _update_nzval!(nzval, jo + A.rowval[Ak], Cx))
15021510
end
15031511
# NOTE: Combining the fill! above into the loop above to avoid multiple sweeps over /
15041512
# nonsequential access of C.nzval does not appear to improve performance.
1505-
return C
1513+
return nzval === C.nzval ? C : SparseMatrixCSC(C.m, C.n, C.colptr, C.rowval, nzval)
15061514
end
15071515
# helper functions for these methods and some of those below
15081516
function _expandstorage!(X::SparseMatrixCSC, maxstored)
@@ -1533,6 +1541,7 @@ function _map_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC, B::Sp
15331541
spaceC = min(length(C.rowval), length(C.nzval))
15341542
rowsentinelA = convert(eltype(A.rowval), C.m + 1)
15351543
rowsentinelB = convert(eltype(B.rowval), C.m + 1)
1544+
nzval = C.nzval
15361545
Ck = 1
15371546
@inbounds for j in 1:C.n
15381547
C.colptr[j] = Ck
@@ -1562,12 +1571,13 @@ function _map_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC, B::Sp
15621571
if Cx != zero(eltype(C))
15631572
Ck > spaceC && (spaceC = _expandstorage!(C, Ck + (nnz(A) - (Ak - 1)) + (nnz(B) - (Bk - 1))))
15641573
C.rowval[Ck] = Ci
1565-
C.nzval[Ck] = Cx
1574+
nzval = _update_nzval!(nzval, Ck, Cx)
15661575
Ck += 1
15671576
end
15681577
end
15691578
end
15701579
@inbounds C.colptr[C.n + 1] = Ck
1580+
nzval === C.nzval || (C = SparseMatrixCSC(C.m, C.n, C.colptr, C.rowval, nzval))
15711581
_trimstorage!(C, Ck - 1)
15721582
return C
15731583
end
@@ -1578,6 +1588,7 @@ function _map_notzeropres!{Tf}(f::Tf, fillvalue, C::SparseMatrixCSC, A::SparseMa
15781588
fill!(C.nzval, fillvalue)
15791589
# NOTE: Combining this fill! into the loop below to avoid multiple sweeps over /
15801590
# nonsequential access of C.nzval does not appear to improve performance.
1591+
nzval = C.nzval
15811592
rowsentinelA = convert(eltype(A.rowval), C.m + 1)
15821593
rowsentinelB = convert(eltype(B.rowval), C.m + 1)
15831594
@inbounds for (j, jo) in zip(1:C.n, 0:C.m:(C.m*C.n - 1))
@@ -1598,17 +1609,18 @@ function _map_notzeropres!{Tf}(f::Tf, fillvalue, C::SparseMatrixCSC, A::SparseMa
15981609
Cx, Ci = f(zero(eltype(A)), B.nzval[Bk]), Bi
15991610
Bk += one(Bk); Bi = Bk < stopBk ? B.rowval[Bk] : rowsentinelB
16001611
end
1601-
Cx != fillvalue && (C.nzval[jo + Ci] = Cx)
1612+
Cx != fillvalue && (nzval = _update_nzval!(nzval, jo + Ci, Cx))
16021613
end
16031614
end
1604-
return C
1615+
return nzval === C.nzval ? C : SparseMatrixCSC(C.m, C.n, C.colptr, C.rowval, nzval)
16051616
end
16061617
# _broadcast_zeropres!/_broadcast_notzeropres! specialized for a pair of (input) sparse matrices
16071618
function _broadcast_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC, B::SparseMatrixCSC)
16081619
isempty(C) && (fill!(C.colptr, 1); return C)
16091620
spaceC = min(length(C.rowval), length(C.nzval))
16101621
rowsentinelA = convert(eltype(A.rowval), C.m + 1)
16111622
rowsentinelB = convert(eltype(B.rowval), C.m + 1)
1623+
nzval = C.nzval
16121624
# A and B cannot have the same shape, as we directed that case to map in broadcast's
16131625
# entry point; here we need efficiently handle only heterogeneous combinations of matrices
16141626
# with no singleton dimensions ("matrices" hereafter), one singleton dimension ("columns"
@@ -1663,7 +1675,7 @@ function _broadcast_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC,
16631675
if Cx != zero(eltype(C))
16641676
Ck > spaceC && (spaceC = _expandstorage!(C, _unchecked_maxnnzbcres(C.m, C.n, A, B)))
16651677
C.rowval[Ck] = Ci
1666-
C.nzval[Ck] = Cx
1678+
nzval = _update_nzval!(nzval, Ck, Cx)
16671679
Ck += 1
16681680
end
16691681
end
@@ -1685,7 +1697,7 @@ function _broadcast_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC,
16851697
if Cx != zero(eltype(C))
16861698
Ck > spaceC && (spaceC = _expandstorage!(C, _unchecked_maxnnzbcres(C.m, C.n, A, B)))
16871699
C.rowval[Ck] = B.rowval[Bk]
1688-
C.nzval[Ck] = Cx
1700+
nzval = _update_nzval!(nzval, Ck, Cx)
16891701
Ck += 1
16901702
end
16911703
Bk += one(Bk)
@@ -1704,7 +1716,7 @@ function _broadcast_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC,
17041716
if Cx != zero(eltype(C))
17051717
Ck > spaceC && (spaceC = _expandstorage!(C, _unchecked_maxnnzbcres(C.m, C.n, A, B)))
17061718
C.rowval[Ck] = Ci
1707-
C.nzval[Ck] = Cx
1719+
nzval = _update_nzval!(nzval, Ck, Cx)
17081720
Ck += 1
17091721
end
17101722
end
@@ -1726,7 +1738,7 @@ function _broadcast_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC,
17261738
if Cx != zero(eltype(C))
17271739
Ck > spaceC && (spaceC = _expandstorage!(C, _unchecked_maxnnzbcres(C.m, C.n, A, B)))
17281740
C.rowval[Ck] = A.rowval[Ak]
1729-
C.nzval[Ck] = Cx
1741+
nzval = _update_nzval!(nzval, Ck, Cx)
17301742
Ck += 1
17311743
end
17321744
Ak += one(Ak)
@@ -1745,14 +1757,15 @@ function _broadcast_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC,
17451757
if Cx != zero(eltype(C))
17461758
Ck > spaceC && (spaceC = _expandstorage!(C, _unchecked_maxnnzbcres(C.m, C.n, A, B)))
17471759
C.rowval[Ck] = Ci
1748-
C.nzval[Ck] = Cx
1760+
nzval = _update_nzval!(nzval, Ck, Cx)
17491761
Ck += 1
17501762
end
17511763
end
17521764
end
17531765
end
17541766
end
17551767
@inbounds C.colptr[C.n + 1] = Ck
1768+
nzval === C.nzval || (C = SparseMatrixCSC(C.m, C.n, C.colptr, C.rowval, nzval))
17561769
_trimstorage!(C, Ck - 1)
17571770
return C
17581771
end
@@ -1764,6 +1777,7 @@ function _broadcast_notzeropres!{Tf}(f::Tf, fillvalue, C::SparseMatrixCSC, A::Sp
17641777
fill!(C.nzval, fillvalue)
17651778
rowsentinelA = convert(eltype(A.rowval), C.m + 1)
17661779
rowsentinelB = convert(eltype(B.rowval), C.m + 1)
1780+
nzval = C.nzval
17671781
# Cases without vertical expansion
17681782
if A.m == B.m
17691783
@inbounds for (j, jo) in zip(1:C.n, 0:C.m:(C.m*C.n - 1))
@@ -1785,7 +1799,7 @@ function _broadcast_notzeropres!{Tf}(f::Tf, fillvalue, C::SparseMatrixCSC, A::Sp
17851799
Ak += one(Ak); Ai = Ak < stopAk ? A.rowval[Ak] : rowsentinelA
17861800
Bk += one(Bk); Bi = Bk < stopBk ? B.rowval[Bk] : rowsentinelB
17871801
end
1788-
Cx != fillvalue && (C.nzval[jo + Ci] = Cx)
1802+
Cx != fillvalue && (nzval = _update_nzval!(nzval, jo + Ci, Cx))
17891803
end
17901804
end
17911805
# Cases with vertical expansion
@@ -1798,7 +1812,7 @@ function _broadcast_notzeropres!{Tf}(f::Tf, fillvalue, C::SparseMatrixCSC, A::Sp
17981812
if fvAzB == zero(eltype(C))
17991813
while Bk < stopBk
18001814
Cx = f(Ax, B.nzval[Bk])
1801-
Cx != fillvalue && (C.nzval[jo + B.rowval[Bk]] = Cx)
1815+
Cx != fillvalue && (nzval = _update_nzval!(nzval, jo + B.rowval[Bk], Cx))
18021816
Bk += one(Bk)
18031817
end
18041818
else
@@ -1810,7 +1824,7 @@ function _broadcast_notzeropres!{Tf}(f::Tf, fillvalue, C::SparseMatrixCSC, A::Sp
18101824
else
18111825
Cx = fvAzB
18121826
end
1813-
Cx != fillvalue && (C.nzval[jo + Ci] = Cx)
1827+
Cx != fillvalue && (nzval = _update_nzval!(nzval, jo + Ci, Cx))
18141828
end
18151829
end
18161830
end
@@ -1823,7 +1837,7 @@ function _broadcast_notzeropres!{Tf}(f::Tf, fillvalue, C::SparseMatrixCSC, A::Sp
18231837
if fzAvB == zero(eltype(C))
18241838
while Ak < stopAk
18251839
Cx = f(A.nzval[Ak], Bx)
1826-
Cx != fillvalue && (C.nzval[jo + A.rowval[Ak]] = Cx)
1840+
Cx != fillvalue && (nzval = _update_nzval!(nzval, jo + A.rowval[Ak], Cx))
18271841
Ak += one(Ak)
18281842
end
18291843
else
@@ -1835,12 +1849,12 @@ function _broadcast_notzeropres!{Tf}(f::Tf, fillvalue, C::SparseMatrixCSC, A::Sp
18351849
else
18361850
Cx = fzAvB
18371851
end
1838-
Cx != fillvalue && (C.nzval[jo + Ci] = Cx)
1852+
Cx != fillvalue && (nzval = _update_nzval!(nzval, jo + Ci, Cx))
18391853
end
18401854
end
18411855
end
18421856
end
1843-
return C
1857+
return nzval === C.nzval ? C : SparseMatrixCSC(C.m, C.n, C.colptr, C.rowval, nzval)
18441858
end
18451859

18461860
# _map_zeropres!/_map_notzeropres! for more than two sparse matrices
@@ -1849,6 +1863,7 @@ function _map_zeropres!{Tf,N}(f::Tf, C::SparseMatrixCSC, As::Vararg{SparseMatrix
18491863
rowsentinel = C.m + 1
18501864
Ck = 1
18511865
stopks = _indforcol_all(1, As)
1866+
nzval = C.nzval
18521867
@inbounds for j in 1:C.n
18531868
C.colptr[j] = Ck
18541869
ks = stopks
@@ -1865,13 +1880,14 @@ function _map_zeropres!{Tf,N}(f::Tf, C::SparseMatrixCSC, As::Vararg{SparseMatrix
18651880
if Cx != zero(eltype(C))
18661881
Ck > spaceC && (spaceC = _expandstorage!(C, Ck + min(length(C), _sumnnzs(As...)) - (sum(ks) - N)))
18671882
C.rowval[Ck] = activerow
1868-
C.nzval[Ck] = Cx
1883+
# Rebind `nzval`: when typeof(Cx) does not match eltype(nzval), `_update_nzval!` returns the
# converted vector rather than its argument, and the rebuild of `C` below relies on that.
nzval = _update_nzval!(nzval, Ck, Cx)
18691884
Ck += 1
18701885
end
18711886
activerow = min(rows...)
18721887
end
18731888
end
18741889
@inbounds C.colptr[C.n + 1] = Ck
1890+
nzval === C.nzval || (C = SparseMatrixCSC(C.m, C.n, C.colptr, C.rowval, nzval))
18751891
_trimstorage!(C, Ck - 1)
18761892
return C
18771893
end
@@ -1884,6 +1900,7 @@ function _map_notzeropres!{Tf,N}(f::Tf, fillvalue, C::SparseMatrixCSC, As::Varar
18841900
# nonsequential access of C.nzval does not appear to improve performance.
18851901
rowsentinel = C.m + 1
18861902
stopks = _indforcol_all(1, As)
1903+
nzval = C.nzval
18871904
@inbounds for (j, jo) in zip(1:C.n, 0:C.m:(C.m*C.n - 1))
18881905
ks = stopks
18891906
stopks = _indforcol_all(j + 1, As)
@@ -1896,7 +1913,7 @@ function _map_notzeropres!{Tf,N}(f::Tf, fillvalue, C::SparseMatrixCSC, As::Varar
18961913
# rows = _updaterow_all(rowsentinel, activerows, rows, ks, stopks, As)
18971914
vals, ks, rows = _fusedupdate_all(rowsentinel, activerow, rows, ks, stopks, As)
18981915
Cx = f(vals...)
1899-
Cx != fillvalue && (C.nzval[jo + activerow] = Cx)
1916+
# Rebind `nzval` so a converted (widened) vector returned by `_update_nzval!` is not dropped.
# NOTE(review): this method's `return` is outside the visible hunk — confirm it rebuilds `C`
# from `nzval` when `nzval !== C.nzval`, as the sibling `_map_notzeropres!` methods do.
Cx != fillvalue && (nzval = _update_nzval!(nzval, jo + activerow, Cx))
19001917
activerow = min(rows...)
19011918
end
19021919
end
@@ -1962,6 +1979,7 @@ function _broadcast_zeropres!{Tf,N}(f::Tf, C::SparseMatrixCSC, As::Vararg{Sparse
19621979
expandsverts = _expandsvert_all(C, As)
19631980
expandshorzs = _expandshorz_all(C, As)
19641981
rowsentinel = C.m + 1
1982+
nzval = C.nzval
19651983
Ck = 1
19661984
@inbounds for j in 1:C.n
19671985
C.colptr[j] = Ck
@@ -1985,7 +2003,7 @@ function _broadcast_zeropres!{Tf,N}(f::Tf, C::SparseMatrixCSC, As::Vararg{Sparse
19852003
if Cx != zero(eltype(C))
19862004
Ck > spaceC && (spaceC = _expandstorage!(C, _unchecked_maxnnzbcres(C.m, C.n, As)))
19872005
C.rowval[Ck] = activerow
1988-
C.nzval[Ck] = Cx
2006+
nzval = _update_nzval!(nzval, Ck, Cx)
19892007
Ck += 1
19902008
end
19912009
activerow = min(rows...)
@@ -2006,13 +2024,14 @@ function _broadcast_zeropres!{Tf,N}(f::Tf, C::SparseMatrixCSC, As::Vararg{Sparse
20062024
if Cx != zero(eltype(C))
20072025
Ck > spaceC && (spaceC = _expandstorage!(C, _unchecked_maxnnzbcres(C.m, C.n, As)))
20082026
C.rowval[Ck] = Ci
2009-
C.nzval[Ck] = Cx
2027+
nzval = _update_nzval!(nzval, Ck, Cx)
20102028
Ck += 1
20112029
end
20122030
end
20132031
end
20142032
end
20152033
@inbounds C.colptr[C.n + 1] = Ck
2034+
nzval === C.nzval || (C = SparseMatrixCSC(C.m, C.n, C.colptr, C.rowval, nzval))
20162035
_trimstorage!(C, Ck - 1)
20172036
return C
20182037
end
@@ -2022,6 +2041,7 @@ function _broadcast_notzeropres!{Tf,N}(f::Tf, fillvalue, C::SparseMatrixCSC, As:
20222041
_densestructure!(C)
20232042
# Populate values
20242043
fill!(C.nzval, fillvalue)
2044+
nzval = C.nzval
20252045
expandsverts = _expandsvert_all(C, As)
20262046
expandshorzs = _expandshorz_all(C, As)
20272047
rowsentinel = C.m + 1
@@ -2043,7 +2063,7 @@ function _broadcast_notzeropres!{Tf,N}(f::Tf, fillvalue, C::SparseMatrixCSC, As:
20432063
# rows = _updaterow_all(rowsentinel, activerows, rows, ks, stopks, As)
20442064
args, ks, rows = _fusedupdatebc_all(rowsentinel, activerow, rows, defargs, ks, stopks, As)
20452065
Cx = f(args...)
2046-
Cx != fillvalue && (C.nzval[jo + activerow] = Cx)
2066+
Cx != fillvalue && (nzval = _update_nzval!(nzval, jo + activerow, Cx))
20472067
activerow = min(rows...)
20482068
end
20492069
else # fillvalue-non-preserving column scan
@@ -2059,11 +2079,11 @@ function _broadcast_notzeropres!{Tf,N}(f::Tf, fillvalue, C::SparseMatrixCSC, As:
20592079
else
20602080
Cx = defaultCx
20612081
end
2062-
Cx != fillvalue && (C.nzval[jo + Ci] = Cx)
2082+
Cx != fillvalue && (nzval = _update_nzval!(nzval, jo + Ci, Cx))
20632083
end
20642084
end
20652085
end
2066-
return C
2086+
return nzval === C.nzval ? C : SparseMatrixCSC(C.m, C.n, C.colptr, C.rowval, nzval)
20672087
end
20682088
# helper method for broadcast/broadcast! methods just above
20692089
@inline _expandsvert(C, A) = A.m != C.m

test/sparse/sparse.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,3 +1828,16 @@ let
18281828
@test_throws DimensionMismatch broadcast(+, A, B, speye(N))
18291829
@test_throws DimensionMismatch broadcast!(+, X, A, B, speye(N))
18301830
end
1831+
1832+
# Issue #19595 - broadcasting over sparse matrices with abstract eltype
1833+
let x = sparse(eye(Real,3,3))
1834+
@test eltype(x) === Real
1835+
@test eltype(x + x) <: Real
1836+
@test eltype(x .+ x) <: Real
1837+
@test eltype(map(+, x, x)) <: Real
1838+
@test eltype(broadcast(+, x, x)) <: Real
1839+
@test eltype(x + x + x) <: Real
1840+
@test eltype(x .+ x .+ x) <: Real
1841+
@test eltype(map(+, map(+, x, x), x)) <: Real
1842+
@test eltype(broadcast(+, broadcast(+, x, x), x)) <: Real
1843+
end

0 commit comments

Comments
 (0)