Skip to content

Commit 8dab405

Browse files
committed
Make sparse operations less dependent on inference
1 parent 1362268 commit 8dab405

File tree

1 file changed

+51
-39
lines changed

1 file changed

+51
-39
lines changed

base/sparse/sparsematrix.jl

Lines changed: 51 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,52 +1401,65 @@ sparse(S::UniformScaling, m::Integer, n::Integer=m) = speye_scaled(S.λ, m, n)
14011401
## map/map! and broadcast/broadcast! over sparse matrices
14021402

14031403
# map/map! entry points
1404-
function map!{Tf,N}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
1404+
function map!{F,N}(f::F, C::SparseMatrixCSC, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
14051405
_checksameshape(C, A, Bs...)
1406+
return _map!(f, C, A, Bs...)
1407+
end
1408+
@inline function _map!(f, C, A, Bs::Vararg)
14061409
fofzeros = f(_zeros_eltypes(A, Bs...)...)
14071410
fpreszeros = fofzeros == zero(fofzeros)
14081411
return fpreszeros ? _map_zeropres!(f, C, A, Bs...) :
14091412
_map_notzeropres!(f, fofzeros, C, A, Bs...)
14101413
end
1411-
function map{Tf,N}(f::Tf, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
1414+
function map{F,N}(f::F, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
14121415
_checksameshape(A, Bs...)
1413-
fofzeros = f(_zeros_eltypes(A, Bs...)...)
1414-
fpreszeros = fofzeros == zero(fofzeros)
1415-
maxnnzC = fpreszeros ? min(length(A), _sumnnzs(A, Bs...)) : length(A)
1416+
return _map(f, A, Bs...)
1417+
end
1418+
@inline function _map(f, A, Bs::Vararg)
14161419
entrytypeC = _broadcast_type(f, A, Bs...)
1417-
indextypeC = _promote_indtype(A, Bs...)
1418-
Ccolptr = Vector{indextypeC}(A.n + 1)
1419-
Crowval = Vector{indextypeC}(maxnnzC)
1420-
Cnzval = Vector{entrytypeC}(maxnnzC)
1421-
C = SparseMatrixCSC(A.m, A.n, Ccolptr, Crowval, Cnzval)
1422-
return fpreszeros ? _map_zeropres!(f, C, A, Bs...) :
1423-
_map_notzeropres!(f, fofzeros, C, A, Bs...)
1420+
if isleaftype(entrytypeC)
1421+
indextypeC = _promote_indtype(A, Bs...)
1422+
fofzeros = f(_zeros_eltypes(A, Bs...)...)
1423+
fpreszeros = fofzeros == zero(fofzeros)
1424+
maxnnzC = fpreszeros ? min(length(A), _sumnnzs(A, Bs...)) : length(A)
1425+
Ccolptr = Vector{indextypeC}(A.n + 1)
1426+
Crowval = Vector{indextypeC}(maxnnzC)
1427+
Cnzval = Vector{entrytypeC}(maxnnzC)
1428+
C = SparseMatrixCSC(A.m, A.n, Ccolptr, Crowval, Cnzval)
1429+
return fpreszeros ? _map_zeropres!(f, C, A, Bs...) :
1430+
_map_notzeropres!(f, fofzeros, C, A, Bs...)
1431+
end
1432+
return sparse(collect(Base.Generator(f, A, Bs...)))
14241433
end
14251434
# broadcast/broadcast! entry points
1426-
broadcast{Tf}(f::Tf, A::SparseMatrixCSC) = map(f, A)
1427-
broadcast!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC) = map!(f, C, A)
1428-
function broadcast!{Tf,N}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
1429-
_aresameshape(C, A, Bs...) && return map!(f, C, A, Bs...) # could avoid a second dims check in map
1435+
broadcast{F}(f::F, A::SparseMatrixCSC) = map(f, A)
1436+
broadcast!{F}(f::F, C::SparseMatrixCSC, A::SparseMatrixCSC) = map!(f, C, A)
1437+
function broadcast!{F,N}(f::F, C::SparseMatrixCSC, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
1438+
_aresameshape(C, A, Bs...) && return _map!(f, C, A, Bs...)
14301439
Base.Broadcast.check_broadcast_indices(indices(C), A, Bs...)
14311440
fofzeros = f(_zeros_eltypes(A, Bs...)...)
14321441
fpreszeros = fofzeros == zero(fofzeros)
14331442
return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) :
14341443
_broadcast_notzeropres!(f, fofzeros, C, A, Bs...)
14351444
end
1436-
function broadcast{Tf,N}(f::Tf, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
1437-
_aresameshape(A, Bs...) && return map(f, A, Bs...) # could avoid a second dims check in map
1438-
fofzeros = f(_zeros_eltypes(A, Bs...)...)
1439-
fpreszeros = fofzeros == zero(fofzeros)
1440-
indextypeC = _promote_indtype(A, Bs...)
1445+
function broadcast{F,N}(f::F, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
1446+
_aresameshape(A, Bs...) && return _map(f, A, Bs...)
14411447
entrytypeC = _broadcast_type(f, A, Bs...)
1442-
Cm, Cn = Base.to_shape(Base.Broadcast.broadcast_indices(A, Bs...))
1443-
maxnnzC = fpreszeros ? _checked_maxnnzbcres(Cm, Cn, A, Bs...) : (Cm * Cn)
1444-
Ccolptr = Vector{indextypeC}(Cn + 1)
1445-
Crowval = Vector{indextypeC}(maxnnzC)
1446-
Cnzval = Vector{entrytypeC}(maxnnzC)
1447-
C = SparseMatrixCSC(Cm, Cn, Ccolptr, Crowval, Cnzval)
1448-
return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) :
1449-
_broadcast_notzeropres!(f, fofzeros, C, A, Bs...)
1448+
shape = Base.Broadcast.broadcast_indices(A, Bs...)
1449+
if isleaftype(entrytypeC)
1450+
indextypeC = _promote_indtype(A, Bs...)
1451+
fofzeros = f(_zeros_eltypes(A, Bs...)...)
1452+
fpreszeros = fofzeros == zero(fofzeros)
1453+
Cm, Cn = Base.to_shape(shape)
1454+
maxnnzC = fpreszeros ? _checked_maxnnzbcres(Cm, Cn, A, Bs...) : (Cm * Cn)
1455+
Ccolptr = Vector{indextypeC}(Cn + 1)
1456+
Crowval = Vector{indextypeC}(maxnnzC)
1457+
Cnzval = Vector{entrytypeC}(maxnnzC)
1458+
C = SparseMatrixCSC(Cm, Cn, Ccolptr, Crowval, Cnzval)
1459+
return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) :
1460+
_broadcast_notzeropres!(f, fofzeros, C, A, Bs...)
1461+
end
1462+
return sparse(Base.Broadcast.broadcast_t(f, Any, shape, CartesianRange(shape), A, Bs...))
14501463
end
14511464
# map/map! and broadcast/broadcast! entry point helper functions
14521465
@inline _sumnnzs(A) = nnz(A)
@@ -1468,7 +1481,7 @@ _broadcast_type(f, As...) = Base._promote_op(f, Base.Broadcast.typestuple(As...)
14681481

14691482
# _map_zeropres!/_map_notzeropres! specialized for a single sparse matrix
14701483
"Stores only the nonzero entries of `map(f, Matrix(A))` in `C`."
1471-
function _map_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC)
1484+
function _map_zeropres!(f, C::SparseMatrixCSC, A::SparseMatrixCSC)
14721485
spaceC = min(length(C.rowval), length(C.nzval))
14731486
Ck = 1
14741487
@inbounds for j in 1:C.n
@@ -1491,7 +1504,7 @@ end
14911504
Densifies `C`, storing `fillvalue` in place of each unstored entry in `A` and
14921505
`f(A[i,j])` in place of each stored entry `A[i,j]` in `A`.
14931506
"""
1494-
function _map_notzeropres!{Tf}(f::Tf, fillvalue, C::SparseMatrixCSC, A::SparseMatrixCSC)
1507+
function _map_notzeropres!(f, fillvalue, C::SparseMatrixCSC, A::SparseMatrixCSC)
14951508
# Build dense matrix structure in C, expanding storage if necessary
14961509
_densestructure!(C)
14971510
# Populate values
@@ -2281,14 +2294,13 @@ round{To}(::Type{To}, A::SparseMatrixCSC) = round.(To, A)
22812294
# TODO: These seven functions should probably be reimplemented in terms of sparse map
22822295
# when a better sparse map exists. (And vectorized min, max, &, |, and xor should be
22832296
# deprecated in favor of compact-broadcast syntax.)
2284-
_checksameshape(A, B) = size(A) == size(B) || throw(DimensionMismatch("size(A) must match size(B)"))
2285-
(+)(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(+, A, B))
2286-
(-)(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(-, A, B))
2287-
min(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(min, A, B))
2288-
max(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(max, A, B))
2289-
(&)(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(&, A, B))
2290-
(|)(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(|, A, B))
2291-
xor(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(xor, A, B))
2297+
(+)(A::SparseMatrixCSC, B::SparseMatrixCSC) = broadcast(+, A, B)
2298+
(-)(A::SparseMatrixCSC, B::SparseMatrixCSC) = broadcast(-, A, B)
2299+
min(A::SparseMatrixCSC, B::SparseMatrixCSC) = broadcast(min, A, B)
2300+
max(A::SparseMatrixCSC, B::SparseMatrixCSC) = broadcast(max, A, B)
2301+
(&)(A::SparseMatrixCSC, B::SparseMatrixCSC) = broadcast(&, A, B)
2302+
(|)(A::SparseMatrixCSC, B::SparseMatrixCSC) = broadcast(|, A, B)
2303+
xor(A::SparseMatrixCSC, B::SparseMatrixCSC) = broadcast(xor, A, B)
22922304

22932305
(.+)(A::SparseMatrixCSC, B::Number) = Array(A) .+ B
22942306
( +)(A::SparseMatrixCSC, B::Array ) = Array(A) + B

0 commit comments

Comments
 (0)