@@ -5,8 +5,6 @@ module HigherOrderFns
5
5
# This module provides higher order functions specialized for sparse arrays,
6
6
# particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present.
7
7
import Base: map, map!, broadcast, broadcast!
8
- import Base. Broadcast: _containertype, promote_containertype,
9
- broadcast_indices, broadcast_c, broadcast_c!
10
8
11
9
using Base: front, tail, to_shape
12
10
using .. SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector,
@@ -23,10 +21,10 @@ using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector,
23
21
# (7) Define _broadcast_[not]zeropres! specialized for a single (input) sparse vector/matrix.
24
22
# (8) Define _broadcast_[not]zeropres! specialized for a pair of (input) sparse vectors/matrices.
25
23
# (9) Define general _broadcast_[not]zeropres! capable of handling >2 (input) sparse vectors/matrices.
26
- # (10) Define ( broadcast[!]) methods handling combinations of broadcast scalars and sparse vectors/matrices.
27
- # (11) Define ( broadcast[!]) methods handling combinations of scalars, sparse vectors/matrices,
24
+ # (10) Define broadcast methods handling combinations of broadcast scalars and sparse vectors/matrices.
25
+ # (11) Define broadcast[!] methods handling combinations of scalars, sparse vectors/matrices,
28
26
# structured matrices, and one- and two-dimensional Arrays.
29
- # (12) Define ( map[!]) methods handling combinations of sparse and structured matrices.
27
+ # (12) Define map[!] methods handling combinations of sparse and structured matrices.
30
28
31
29
32
30
# (1) The definitions below provide a common interface to sparse vectors and matrices
@@ -85,7 +83,7 @@ function _noshapecheck_map(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N
85
83
fofzeros = f (_zeros_eltypes (A, Bs... )... )
86
84
fpreszeros = _iszero (fofzeros)
87
85
maxnnzC = fpreszeros ? min (length (A), _sumnnzs (A, Bs... )) : length (A)
88
- entrytypeC = Base. Broadcast. _broadcast_eltype (f, A, Bs... )
86
+ entrytypeC = Base. Broadcast. combine_eltypes (f, A, Bs... )
89
87
indextypeC = _promote_indtype (A, Bs... )
90
88
C = _allocres (size (A), indextypeC, entrytypeC, maxnnzC)
91
89
return fpreszeros ? _map_zeropres! (f, C, A, Bs... ) :
@@ -126,8 +124,8 @@ function _diffshape_broadcast(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMa
126
124
fofzeros = f (_zeros_eltypes (A, Bs... )... )
127
125
fpreszeros = _iszero (fofzeros)
128
126
indextypeC = _promote_indtype (A, Bs... )
129
- entrytypeC = Base. Broadcast. _broadcast_eltype (f, A, Bs... )
130
- shapeC = to_shape (Base. Broadcast. broadcast_indices (A, Bs... ))
127
+ entrytypeC = Base. Broadcast. combine_eltypes (f, A, Bs... )
128
+ shapeC = to_shape (Base. Broadcast. combine_indices (A, Bs... ))
131
129
maxnnzC = fpreszeros ? _checked_maxnnzbcres (shapeC, A, Bs... ) : _densennz (shapeC)
132
130
C = _allocres (shapeC, indextypeC, entrytypeC, maxnnzC)
133
131
return fpreszeros ? _broadcast_zeropres! (f, C, A, Bs... ) :
@@ -897,29 +895,25 @@ end
897
895
end
898
896
899
897
900
- # (10) broadcast[!] over combinations of broadcast scalars and sparse vectors/matrices
898
+ # (10) broadcast over combinations of broadcast scalars and sparse vectors/matrices
901
899
902
- # broadcast shape promotion for combinations of sparse arrays and other types
903
- broadcast_indices (:: Type{AbstractSparseArray} , A) = indices (A)
904
900
# broadcast container type promotion for combinations of sparse arrays and other types
905
- _containertype ( :: Type{<:SparseVecOrMat} ) = AbstractSparseArray
906
- # combinations of sparse arrays with broadcast scalars should yield sparse arrays
907
- promote_containertype ( :: Type{Any} , :: Type{AbstractSparseArray} ) = AbstractSparseArray
908
- promote_containertype (:: Type{AbstractSparseArray } , :: Type{Any} ) = AbstractSparseArray
909
- # combinations of sparse arrays with tuples should divert to the generic AbstractArray broadcast code
910
- # (we handle combinations involving dense vectors/matrices below)
911
- promote_containertype ( :: Type{Tuple} , :: Type{AbstractSparseArray} ) = Array
912
- promote_containertype (:: Type{AbstractSparseArray } , :: Type{Tuple } ) = Array
901
+ # inference has a hard time with Union type returns, so define a specific "signal" type
902
+ abstract type SPVM end
903
+ # Because it's not a subtype of AbstractArray, we have to define Broadcast.indices
904
+ Broadcast . indices (:: Type{SPVM } , A ) = Base . indices (A)
905
+ Broadcast . rule ( :: Type{<:SparseVector} ) = SPVM
906
+ Broadcast . rule ( :: Type{<:SparseMatrixCSC} ) = SPVM
907
+ # Scalars lose to SPVM
908
+ Broadcast . rule (:: Type{SPVM } , :: Type{Broadcast.Scalar } ) = SPVM
913
909
914
- # broadcast[!] entry points for combinations of sparse arrays and other (scalar) types
915
- @inline function broadcast_c (f, :: Type{AbstractSparseArray } , mixedargs:: Vararg{Any,N} ) where N
910
+ # broadcast entry points for combinations of sparse arrays and other (scalar) types
911
+ function broadcast (f, r :: Broadcast.Result{SPVM,Void,Void } , mixedargs:: Vararg{Any,N} ) where N
916
912
parevalf, passedargstup = capturescalars (f, mixedargs)
917
913
return broadcast (parevalf, passedargstup... )
918
914
end
919
- @inline function broadcast_c! (f, :: Type{AbstractSparseArray} , dest:: SparseVecOrMat , mixedsrcargs:: Vararg{Any,N} ) where N
920
- parevalf, passedsrcargstup = capturescalars (f, mixedsrcargs)
921
- return broadcast! (parevalf, dest, passedsrcargstup... )
922
- end
915
+ # for broadcast! see (11)
916
+
923
917
# capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and
924
918
# broadcast scalar arguments (mixedargs), and returns a function (parevalf, i.e. partially
925
919
# evaluated f) and a reduced argument tuple (passedargstup) containing only the sparse
@@ -969,99 +963,70 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f(
969
963
# for combinations involving only scalars, sparse arrays, structured matrices, and dense
970
964
# vectors/matrices, promote all structured matrices and dense vectors/matrices to sparse
971
965
# and rebroadcast. otherwise, divert to generic AbstractArray broadcast code.
972
- #
973
- # this requires three steps: segregate combinations to promote to sparse via Broadcast's
974
- # containertype promotion and dispatch layer (broadcast_c[!], containertype,
975
- # promote_containertype), separate ambiguous cases from the preceding dispatch
976
- # layer in sparse broadcast's internal containertype promotion and dispatch layer
977
- # (spbroadcast_c[!], spcontainertype, promote_spcontainertype), and then promote
978
- # arguments to sparse as appropriate and rebroadcast.
979
-
980
966
981
- # first (Broadcast containertype) dispatch layer's promotion logic
982
- struct PromoteToSparse end
967
+ # combinations of sparse arrays, tuples, and arrays of dimensionality 0-2 should yield sparse arrays
968
+ abstract type PromoteToSparse end
969
+ # Since we're not making PromoteToSparse a subtype of AbstractArray, we need to define indices
970
+ Broadcast. indices (:: Type{PromoteToSparse} , A) = Base. indices (A)
983
971
984
- # broadcast containertype definitions for structured matrices
985
972
StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
986
- _containertype (:: Type{<:StructuredMatrix} ) = PromoteToSparse
987
- broadcast_indices (:: Type{PromoteToSparse} , A) = indices (A)
973
+ Broadcast. rule (:: Type{<:StructuredMatrix} ) = PromoteToSparse
988
974
989
- # combinations explicitly involving Tuples and PromoteToSparse collections
990
- # divert to the generic AbstractArray broadcast code
991
- promote_containertype (:: Type{PromoteToSparse } , :: Type{Tuple } ) = Array
992
- promote_containertype (:: Type{Tuple } , :: Type{PromoteToSparse } ) = Array
993
- # combinations involving scalars and PromoteToSparse collections continue in the promote-to-sparse funnel
994
- promote_containertype (:: Type{PromoteToSparse} , :: Type{Any } ) = PromoteToSparse
995
- promote_containertype (:: Type{Any } , :: Type{PromoteToSparse } ) = PromoteToSparse
996
- # combinations involving sparse arrays and PromoteToSparse collections continue in the promote-to-sparse funnel
997
- promote_containertype ( :: Type{PromoteToSparse} , :: Type{AbstractSparseArray} ) = PromoteToSparse
998
- promote_containertype ( :: Type{AbstractSparseArray} , :: Type{PromoteToSparse} ) = PromoteToSparse
999
- # combinations involving Arrays and PromoteToSparse collections continue in the promote-to-sparse funnel
1000
- promote_containertype (:: Type{PromoteToSparse } , :: Type{Array} ) = PromoteToSparse
1001
- promote_containertype (:: Type{Array } , :: Type{PromoteToSparse} ) = PromoteToSparse
1002
- # combinations involving Arrays and sparse arrays continue in the promote-to-sparse funnel
1003
- promote_containertype (:: Type{AbstractSparseArray } , :: Type{Array } ) = PromoteToSparse
1004
- promote_containertype (:: Type{Array } , :: Type{AbstractSparseArray } ) = PromoteToSparse
975
+ Broadcast . rule ( :: Type{SPVM} , :: Type{Broadcast.Bottom0d} ) = PromoteToSparse
976
+ Broadcast . rule ( :: Type{SPVM} , :: Type{Broadcast.BottomVector} ) = PromoteToSparse
977
+ Broadcast . rule (:: Type{SPVM } , :: Type{Broadcast.BottomMatrix } ) = PromoteToSparse
978
+ Broadcast . rule (:: Type{PromoteToSparse } , :: Type{Broadcast.Scalar } ) = PromoteToSparse
979
+ Broadcast . rule ( :: Type{ PromoteToSparse} , :: Type{Broadcast.Bottom0d} ) = PromoteToSparse
980
+ Broadcast . rule (:: Type{PromoteToSparse} , :: Type{Broadcast.BottomVector } ) = PromoteToSparse
981
+ Broadcast . rule (:: Type{PromoteToSparse } , :: Type{Broadcast.BottomMatrix } ) = PromoteToSparse
982
+ Broadcast . rule ( :: Type{PromoteToSparse} , :: Type{SPVM} ) = PromoteToSparse
983
+ # Combinations of sparse arrays and higher-dimensional arrays default to the generic infrastructure.
984
+ # In this case it's best to keep the same argument order as above, so that we don't get ambiguities
985
+ # or conflicting rules
986
+ Broadcast . rule (:: Type{SPVM } , :: Type{Broadcast.BottomArray{N}} ) where N = Broadcast . BottomArray{N}
987
+ Broadcast . rule (:: Type{PromoteToSparse } , :: Type{Broadcast.BottomArray{N}} ) where N = Broadcast . BottomArray{N}
988
+ # Tuples lead to dense outputs
989
+ Broadcast . rule (:: Type{SPVM } , :: Type{Tuple } ) = Broadcast . BottomArray{ 2 }
990
+ Broadcast . rule (:: Type{PromoteToSparse } , :: Type{Tuple } ) = Broadcast . BottomArray{ 2 }
1005
991
1006
- # second (internal sparse broadcast containertype) dispatch layer's promotion logic
1007
- # mostly just disambiguates Array from the main containertype promotion mechanism
1008
- # AbstractArray serves as a marker to shunt to the generic AbstractArray broadcast code
1009
- _spcontainertype (x) = _containertype (x)
1010
- _spcontainertype (:: Type{<:Vector} ) = Vector
1011
- _spcontainertype (:: Type{<:Matrix} ) = Matrix
1012
- _spcontainertype (:: Type{<:RowVector} ) = Matrix
1013
- _spcontainertype (:: Type{<:Ref} ) = AbstractArray
1014
- _spcontainertype (:: Type{<:AbstractArray} ) = AbstractArray
1015
- # need the following two methods to override the immediately preceding method
1016
- _spcontainertype (:: Type{<:StructuredMatrix} ) = PromoteToSparse
1017
- _spcontainertype (:: Type{<:SparseVecOrMat} ) = AbstractSparseArray
1018
- spcontainertype (x) = _spcontainertype (typeof (x))
1019
- spcontainertype (ct1, ct2) = promote_spcontainertype (spcontainertype (ct1), spcontainertype (ct2))
1020
- @inline spcontainertype (ct1, ct2, cts... ) = promote_spcontainertype (spcontainertype (ct1), spcontainertype (ct2, cts... ))
1021
-
1022
- promote_spcontainertype (:: Type{T} , :: Type{T} ) where {T} = T
1023
- # combinations involving AbstractArrays and/or Tuples divert to the generic AbstractArray broadcast code
1024
- DivertToAbsArrayBC = Union{Type{AbstractArray},Type{Tuple}}
1025
- promote_spcontainertype (:: DivertToAbsArrayBC , ct) = AbstractArray
1026
- promote_spcontainertype (ct, :: DivertToAbsArrayBC ) = AbstractArray
1027
- promote_spcontainertype (:: DivertToAbsArrayBC , :: DivertToAbsArrayBC ) = AbstractArray
1028
- # combinations involving scalars, sparse arrays, structured matrices (PromoteToSparse),
1029
- # dense vectors/matrices, and PromoteToSparse collections continue in the promote-to-sparse funnel
1030
- FunnelToSparseBC = Union{Type{Any},Type{Vector},Type{Matrix},Type{PromoteToSparse},Type{AbstractSparseArray}}
1031
- promote_spcontainertype (:: FunnelToSparseBC , :: FunnelToSparseBC ) = PromoteToSparse
1032
992
993
+ broadcast (f, r:: Broadcast.Result{PromoteToSparse,Void,Void} , As:: Vararg{Any,N} ) where {N} =
994
+ broadcast (f, map (_sparsifystructured, As)... )
1033
995
1034
- # first (Broadcast containertype) dispatch layer
1035
- # (broadcast_c[!], containertype, promote_containertype)
1036
- @inline broadcast_c (f, :: Type{PromoteToSparse} , As:: Vararg{Any,N} ) where {N} =
1037
- spbroadcast_c (f, spcontainertype (As... ), As... )
1038
- @inline broadcast_c! (f, :: Type{AbstractSparseArray} , :: Type{PromoteToSparse} , C, B, As:: Vararg{Any,N} ) where {N} =
1039
- spbroadcast_c! (f, AbstractSparseArray, spcontainertype (B, As... ), C, B, As... )
1040
- # where destination C is not an AbstractSparseArray, divert to generic AbstractArray broadcast code
1041
- @inline broadcast_c! (f, CT:: Type , :: Type{PromoteToSparse} , C, B, As:: Vararg{Any,N} ) where {N} =
1042
- broadcast_c! (f, CT, Array, C, B, As... )
996
+ # ambiguity resolution
997
+ broadcast! (:: typeof (identity), dest:: SparseVecOrMat , x:: Number ) =
998
+ fill! (dest, x)
999
+ broadcast! (f, dest:: SparseVecOrMat , x:: Number... ) =
1000
+ spbroadcast_args! (f, dest, SPVM, mixedsrcargs... )
1043
1001
1044
- # second (internal sparse broadcast containertype) dispatch layer
1045
- # (spbroadcast_c[!], spcontainertype, promote_spcontainertype)
1046
- @inline spbroadcast_c (f, :: Type{PromoteToSparse} , As:: Vararg{Any,N} ) where {N} =
1047
- broadcast (f, map (_sparsifystructured, As)... )
1048
- @inline spbroadcast_c (f, :: Type{AbstractArray} , As:: Vararg{Any,N} ) where {N} =
1049
- broadcast_c (f, Array, As... )
1050
- @inline spbroadcast_c! (f, :: Type{AbstractSparseArray} , :: Type{PromoteToSparse} , C, B, As:: Vararg{Any,N} ) where {N} =
1051
- broadcast! (f, C, _sparsifystructured (B), map (_sparsifystructured, As)... )
1052
- @inline spbroadcast_c! (f, :: Type{AbstractSparseArray} , :: Type{AbstractArray} , C, B, As:: Vararg{Any,N} ) where {N} =
1053
- broadcast_c! (f, Array, Array, C, B, As... )
1002
+ # For broadcast! with ::Any inputs, we need a layer of indirection to determine whether
1003
+ # the inputs can be promoted to SparseVecOrMat. If it's just SparseVecOrMat and scalars,
1004
+ # we can handle it here, otherwise see below for the promotion machinery.
1005
+ broadcast! (f, dest:: SparseVecOrMat , mixedsrcargs:: Vararg{Any,N} ) where N =
1006
+ spbroadcast_args! (f, dest, Broadcast. combine_types (mixedsrcargs... ), mixedsrcargs... )
1007
+ function spbroadcast_args! (f, dest, :: Type{SPVM} , mixedsrcargs:: Vararg{Any,N} ) where N
1008
+ # mixedsrcargs contains nothing but SparseVecOrMat and scalars
1009
+ parevalf, passedsrcargstup = capturescalars (f, mixedsrcargs)
1010
+ return broadcast! (parevalf, dest, passedsrcargstup... )
1011
+ end
1012
+ function spbroadcast_args! (f, dest, :: Type{PromoteToSparse} , mixedsrcargs:: Vararg{Any,N} ) where N
1013
+ broadcast! (f, dest, map (_sparsifystructured, mixedsrcargs)... )
1014
+ end
1015
+ function spbroadcast_args! (f, dest, :: Type , mixedsrcargs:: Vararg{Any,N} ) where N
1016
+ # Fallback. From a performance perspective would it be best to densify?
1017
+ Broadcast. _broadcast! (f, dest, mixedsrcargs... )
1018
+ end
1054
1019
1055
- @inline _sparsifystructured (M:: AbstractMatrix ) = SparseMatrixCSC (M)
1056
- @inline _sparsifystructured (V:: AbstractVector ) = SparseVector (V)
1057
- @inline _sparsifystructured (M:: AbstractSparseMatrix ) = SparseMatrixCSC (M)
1058
- @inline _sparsifystructured (V:: AbstractSparseVector ) = SparseVector (V)
1059
- @inline _sparsifystructured (S:: SparseVecOrMat ) = S
1060
- @inline _sparsifystructured (x) = x
1020
+ _sparsifystructured (M:: AbstractMatrix ) = SparseMatrixCSC (M)
1021
+ _sparsifystructured (V:: AbstractVector ) = SparseVector (V)
1022
+ _sparsifystructured (P:: AbstractArray{T,0} ) where T = SparseVector (reshape (P, 1 ))
1023
+ _sparsifystructured (M:: AbstractSparseMatrix ) = SparseMatrixCSC (M)
1024
+ _sparsifystructured (V:: AbstractSparseVector ) = SparseVector (V)
1025
+ _sparsifystructured (S:: SparseVecOrMat ) = S
1026
+ _sparsifystructured (x) = x
1061
1027
1062
1028
1063
1029
# (12) map[!] over combinations of sparse and structured matrices
1064
- StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
1065
1030
SparseOrStructuredMatrix = Union{SparseMatrixCSC,StructuredMatrix}
1066
1031
map (f:: Tf , A:: StructuredMatrix ) where {Tf} = _noshapecheck_map (f, _sparsifystructured (A))
1067
1032
map (f:: Tf , A:: SparseOrStructuredMatrix , Bs:: Vararg{SparseOrStructuredMatrix,N} ) where {Tf,N} =
0 commit comments