@@ -4,14 +4,15 @@ module HigherOrderFns
4
4
5
5
# This module provides higher order functions specialized for sparse arrays,
6
6
# particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present.
7
- import Base: map, map!, broadcast, broadcast !
7
+ import Base: map, map!, broadcast, copy, copyto !
8
8
9
- using Base: front, tail, to_shape
9
+ using Base: TupleLL, front, tail, to_shape
10
10
using .. SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector,
11
11
AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange
12
- using Base. Broadcast: BroadcastStyle
12
+ using Base. Broadcast: BroadcastStyle, Broadcasted, flatten
13
13
14
14
# This module is organized as follows:
15
+ # (0) Define BroadcastStyle rules and convenience types for dispatch
15
16
# (1) Define a common interface to SparseVectors and SparseMatrixCSCs sufficient for
16
17
# map[!]/broadcast[!]'s purposes. The methods below are written against this interface.
17
18
# (2) Define entry points for map[!] (short children of _map_[not]zeropres!).
@@ -28,11 +29,70 @@ using Base.Broadcast: BroadcastStyle
28
29
# (12) Define map[!] methods handling combinations of sparse and structured matrices.
29
30
30
31
32
+ # (0) BroadcastStyle rules and convenience types for dispatch
33
+
34
+ SparseVecOrMat = Union{SparseVector,SparseMatrixCSC}
35
+
36
+ # broadcast container type promotion for combinations of sparse arrays and other types
37
+ struct SparseVecStyle <: Broadcast.AbstractArrayStyle{1} end
38
+ struct SparseMatStyle <: Broadcast.AbstractArrayStyle{2} end
39
+ Broadcast. BroadcastStyle (:: Type{<:SparseVector} ) = SparseVecStyle ()
40
+ Broadcast. BroadcastStyle (:: Type{<:SparseMatrixCSC} ) = SparseMatStyle ()
41
+ const SPVM = Union{SparseVecStyle,SparseMatStyle}
42
+
43
+ # SparseVecStyle handles 0-1 dimensions, SparseMatStyle 0-2 dimensions.
44
+ # SparseVecStyle promotes to SparseMatStyle for 2 dimensions.
45
+ # Fall back to DefaultArrayStyle for higher dimensionality.
46
+ SparseVecStyle (:: Val{0} ) = SparseVecStyle ()
47
+ SparseVecStyle (:: Val{1} ) = SparseVecStyle ()
48
+ SparseVecStyle (:: Val{2} ) = SparseMatStyle ()
49
+ SparseVecStyle (:: Val{N} ) where N = Broadcast. DefaultArrayStyle {N} ()
50
+ SparseMatStyle (:: Val{0} ) = SparseMatStyle ()
51
+ SparseMatStyle (:: Val{1} ) = SparseMatStyle ()
52
+ SparseMatStyle (:: Val{2} ) = SparseMatStyle ()
53
+ SparseMatStyle (:: Val{N} ) where N = Broadcast. DefaultArrayStyle {N} ()
54
+
55
+ Broadcast. BroadcastStyle (:: SparseMatStyle , :: SparseVecStyle ) = SparseMatStyle ()
56
+
57
+ struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end
58
+ StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
59
+ Broadcast. BroadcastStyle (:: Type{<:StructuredMatrix} ) = PromoteToSparse ()
60
+
61
+ PromoteToSparse (:: Val{0} ) = PromoteToSparse ()
62
+ PromoteToSparse (:: Val{1} ) = PromoteToSparse ()
63
+ PromoteToSparse (:: Val{2} ) = PromoteToSparse ()
64
+ PromoteToSparse (:: Val{N} ) where N = Broadcast. DefaultArrayStyle {N} ()
65
+
66
+ Broadcast. BroadcastStyle (:: PromoteToSparse , :: SPVM ) = PromoteToSparse ()
67
+
68
+ # FIXME : switch to DefaultArrayStyle once we can delete VectorStyle and MatrixStyle
69
+ BroadcastStyle (:: Type{<:Base.Adjoint{T,<:Vector}} ) where T = Broadcast. MatrixStyle () # Adjoint not yet defined when broadcast.jl loaded
70
+ BroadcastStyle (:: Type{<:Base.Transpose{T,<:Vector}} ) where T = Broadcast. MatrixStyle () # Transpose not yet defined when broadcast.jl loaded
71
+ Broadcast. BroadcastStyle (:: SPVM , :: Broadcast.VectorStyle ) = PromoteToSparse ()
72
+ Broadcast. BroadcastStyle (:: SPVM , :: Broadcast.MatrixStyle ) = PromoteToSparse ()
73
+ Broadcast. BroadcastStyle (:: SparseVecStyle , :: Broadcast.DefaultArrayStyle{N} ) where N =
74
+ Broadcast. DefaultArrayStyle (Broadcast. _max (Val (N), Val (1 )))
75
+ Broadcast. BroadcastStyle (:: SparseMatStyle , :: Broadcast.DefaultArrayStyle{N} ) where N =
76
+ Broadcast. DefaultArrayStyle (Broadcast. _max (Val (N), Val (2 )))
77
+ # end FIXME
78
+
79
+ # Tuples promote to dense
80
+ Broadcast. BroadcastStyle (:: SparseVecStyle , :: Broadcast.Style{Tuple} ) = Broadcast. DefaultArrayStyle {1} ()
81
+ Broadcast. BroadcastStyle (:: SparseMatStyle , :: Broadcast.Style{Tuple} ) = Broadcast. DefaultArrayStyle {2} ()
82
+ Broadcast. BroadcastStyle (:: PromoteToSparse , :: Broadcast.Style{Tuple} ) = Broadcast. DefaultArrayStyle {2} ()
83
+
84
+ # Dispatch on broadcast operations by number of arguments
85
+ const Broadcasted0{Style<: Union{Nothing,BroadcastStyle} ,ElType,Axes,Indexing<: Union{Nothing,TupleLL{Nothing,Nothing}} ,F} =
86
+ Broadcasted{Style,ElType,Axes,Indexing,F,TupleLL{Nothing,Nothing}}
87
+ const SpBroadcasted1{Style<: SPVM ,ElType,Axes,Indexing<: Union{Nothing,TupleLL} ,F,Args<: TupleLL{<:SparseVecOrMat,Nothing} } =
88
+ Broadcasted{Style,ElType,Axes,Indexing,F,Args}
89
+ const SpBroadcasted2{Style<: SPVM ,ElType,Axes,Indexing<: Union{Nothing,TupleLL} ,F,Args<: TupleLL{<:SparseVecOrMat,TupleLL{<:SparseVecOrMat,Nothing}} } =
90
+ Broadcasted{Style,ElType,Axes,Indexing,F,Args}
91
+
31
92
# (1) The definitions below provide a common interface to sparse vectors and matrices
32
93
# sufficient for the purposes of map[!]/broadcast[!]. This interface treats sparse vectors
33
94
# as n-by-one sparse matrices which, though technically incorrect, is how broacast[!] views
34
95
# sparse vectors in practice.
35
- SparseVecOrMat = Union{SparseVector,SparseMatrixCSC}
36
96
@inline numrows (A:: SparseVector ) = A. n
37
97
@inline numrows (A:: SparseMatrixCSC ) = A. m
38
98
@inline numcols (A:: SparseVector ) = 1
@@ -91,11 +151,11 @@ function _noshapecheck_map(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N
91
151
_map_notzeropres! (f, fofzeros, C, A, Bs... )
92
152
end
93
153
# (3) broadcast[!] entry points
94
- broadcast (f:: Tf , A:: SparseVector ) where {Tf} = _noshapecheck_map (f, A)
95
- broadcast (f:: Tf , A:: SparseMatrixCSC ) where {Tf} = _noshapecheck_map (f, A)
154
+ copy (bc:: SpBroadcasted1 ) = _noshapecheck_map (bc. f, bc. args. head)
96
155
97
- @inline function broadcast! (f :: Tf , C:: SparseVecOrMat , :: Nothing ) where Tf
156
+ @inline function copyto! ( C:: SparseVecOrMat , bc :: Broadcasted0{ Nothing} )
98
157
isempty (C) && return _finishempty! (C)
158
+ f = bc. f
99
159
fofnoargs = f ()
100
160
if _iszero (fofnoargs) # f() is zero, so empty C
101
161
trimstorage! (C, 0 )
@@ -108,13 +168,6 @@ broadcast(f::Tf, A::SparseMatrixCSC) where {Tf} = _noshapecheck_map(f, A)
108
168
return C
109
169
end
110
170
111
- # the following three similar defs are necessary for type stability in the mixed vector/matrix case
112
- broadcast (f:: Tf , A:: SparseVector , Bs:: Vararg{SparseVector,N} ) where {Tf,N} =
113
- _aresameshape (A, Bs... ) ? _noshapecheck_map (f, A, Bs... ) : _diffshape_broadcast (f, A, Bs... )
114
- broadcast (f:: Tf , A:: SparseMatrixCSC , Bs:: Vararg{SparseMatrixCSC,N} ) where {Tf,N} =
115
- _aresameshape (A, Bs... ) ? _noshapecheck_map (f, A, Bs... ) : _diffshape_broadcast (f, A, Bs... )
116
- broadcast (f:: Tf , A:: SparseVecOrMat , Bs:: Vararg{SparseVecOrMat,N} ) where {Tf,N} =
117
- _diffshape_broadcast (f, A, Bs... )
118
171
function _diffshape_broadcast (f:: Tf , A:: SparseVecOrMat , Bs:: Vararg{SparseVecOrMat,N} ) where {Tf,N}
119
172
fofzeros = f (_zeros_eltypes (A, Bs... )... )
120
173
fpreszeros = _iszero (fofzeros)
139
192
@inline _aresameshape (A) = true
140
193
@inline _aresameshape (A, B) = size (A) == size (B)
141
194
@inline _aresameshape (A, B, Cs... ) = _aresameshape (A, B) ? _aresameshape (B, Cs... ) : false
195
+ @inline _aresameshape (t:: TupleLL{<:Any,Nothing} ) = true
196
+ @inline _aresameshape (t:: TupleLL{<:Any,<:TupleLL} ) =
197
+ _aresameshape (t. head, t. rest. head) ? _aresameshape (t. rest) : false
142
198
@inline _checksameshape (As... ) = _aresameshape (As... ) || throw (DimensionMismatch (" argument shapes must match" ))
199
+ @inline _all_args_isa (t:: TupleLL{<:Any,Nothing} , :: Type{T} ) where T = isa (t. head, T)
200
+ @inline _all_args_isa (t:: TupleLL , :: Type{T} ) where T = isa (t. head, T) & _all_args_isa (t. rest, T)
201
+ @inline _all_args_isa (t:: TupleLL{<:Broadcasted,Nothing} , :: Type{T} ) where T = _all_args_isa (t. head. args, T)
202
+ @inline _all_args_isa (t:: TupleLL{<:Broadcasted} , :: Type{T} ) where T = _all_args_isa (t. head. args, T) & _all_args_isa (t. rest, T)
143
203
@inline _densennz (shape:: NTuple{1} ) = shape[1 ]
144
204
@inline _densennz (shape:: NTuple{2} ) = shape[1 ] * shape[2 ]
145
205
_maxnnzfrom (shape:: NTuple{1} , A) = nnz (A) * div (shape[1 ], A. n)
@@ -892,37 +952,42 @@ end
892
952
893
953
# (10) broadcast over combinations of broadcast scalars and sparse vectors/matrices
894
954
895
- # broadcast container type promotion for combinations of sparse arrays and other types
896
- struct SparseVecStyle <: Broadcast.AbstractArrayStyle{1} end
897
- struct SparseMatStyle <: Broadcast.AbstractArrayStyle{2} end
898
- Broadcast. BroadcastStyle (:: Type{<:SparseVector} ) = SparseVecStyle ()
899
- Broadcast. BroadcastStyle (:: Type{<:SparseMatrixCSC} ) = SparseMatStyle ()
900
- const SPVM = Union{SparseVecStyle,SparseMatStyle}
901
-
902
- # SparseVecStyle handles 0-1 dimensions, SparseMatStyle 0-2 dimensions.
903
- # SparseVecStyle promotes to SparseMatStyle for 2 dimensions.
904
- # Fall back to DefaultArrayStyle for higher dimensionality.
905
- SparseVecStyle (:: Val{0} ) = SparseVecStyle ()
906
- SparseVecStyle (:: Val{1} ) = SparseVecStyle ()
907
- SparseVecStyle (:: Val{2} ) = SparseMatStyle ()
908
- SparseVecStyle (:: Val{N} ) where N = Broadcast. DefaultArrayStyle {N} ()
909
- SparseMatStyle (:: Val{0} ) = SparseMatStyle ()
910
- SparseMatStyle (:: Val{1} ) = SparseMatStyle ()
911
- SparseMatStyle (:: Val{2} ) = SparseMatStyle ()
912
- SparseMatStyle (:: Val{N} ) where N = Broadcast. DefaultArrayStyle {N} ()
913
-
914
- Broadcast. BroadcastStyle (:: SparseMatStyle , :: SparseVecStyle ) = SparseMatStyle ()
915
-
916
- # Tuples promote to dense
917
- Broadcast. BroadcastStyle (:: SparseVecStyle , :: Broadcast.Style{Tuple} ) = Broadcast. DefaultArrayStyle {1} ()
918
- Broadcast. BroadcastStyle (:: SparseMatStyle , :: Broadcast.Style{Tuple} ) = Broadcast. DefaultArrayStyle {2} ()
919
-
920
955
# broadcast entry points for combinations of sparse arrays and other (scalar) types
921
- function broadcast (f, :: SPVM , :: Nothing , :: Nothing , mixedargs:: Vararg{Any,N} ) where N
922
- parevalf, passedargstup = capturescalars (f, mixedargs)
956
+ function copy (bc:: Broadcasted{<:SPVM} )
957
+ bcf = flatten (bc)
958
+ _all_args_isa (bcf. args, SparseVector) && return _shapecheckbc (bcf)
959
+ _all_args_isa (bcf. args, SparseMatrixCSC) && return _shapecheckbc (bcf)
960
+ args = Tuple (bcf. args)
961
+ _all_args_isa (bcf. args, SparseVecOrMat) && return _diffshape_broadcast (bcf. f, args... )
962
+ parevalf, passedargstup = capturescalars (bcf. f, args)
923
963
return broadcast (parevalf, passedargstup... )
924
964
end
925
- # for broadcast! see (11)
965
+ function _shapecheckbc (bc:: Broadcasted )
966
+ args = Tuple (bc. args)
967
+ _aresameshape (bc. args) ? _noshapecheck_map (bc. f, args... ) : _diffshape_broadcast (bc. f, args... )
968
+ end
969
+
970
+ function copyto! (dest:: SparseVecOrMat , bc:: Broadcasted{<:SPVM} )
971
+ if bc. f === identity && bc isa SpBroadcasted1 && Base. axes (dest) == (A = bc. args. head; Base. axes (A))
972
+ return copyto! (dest, A)
973
+ end
974
+ bcf = flatten (bc)
975
+ As = Tuple (bcf. args)
976
+ if _all_args_isa (bcf. args, SparseVecOrMat)
977
+ _aresameshape (dest, As... ) && return _noshapecheck_map! (bcf. f, dest, As... )
978
+ Base. Broadcast. check_broadcast_indices (axes (dest), As... )
979
+ fofzeros = bcf. f (_zeros_eltypes (As... )... )
980
+ fpreszeros = _iszero (fofzeros)
981
+ fpreszeros ? _broadcast_zeropres! (bcf. f, dest, As... ) :
982
+ _broadcast_notzeropres! (bcf. f, fofzeros, dest, As... )
983
+ else
984
+ # As contains nothing but SparseVecOrMat and scalars
985
+ # See below for capturescalars
986
+ parevalf, passedsrcargstup = capturescalars (bcf. f, As)
987
+ broadcast! (parevalf, dest, passedsrcargstup... )
988
+ end
989
+ return dest
990
+ end
926
991
927
992
# capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and
928
993
# broadcast scalar arguments (mixedargs), and returns a function (parevalf, i.e. partially
@@ -971,59 +1036,16 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f(
971
1036
# vectors/matrices, promote all structured matrices and dense vectors/matrices to sparse
972
1037
# and rebroadcast. otherwise, divert to generic AbstractArray broadcast code.
973
1038
974
- struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end
975
- StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
976
- Broadcast. BroadcastStyle (:: Type{<:StructuredMatrix} ) = PromoteToSparse ()
977
-
978
- PromoteToSparse (:: Val{0} ) = PromoteToSparse ()
979
- PromoteToSparse (:: Val{1} ) = PromoteToSparse ()
980
- PromoteToSparse (:: Val{2} ) = PromoteToSparse ()
981
- PromoteToSparse (:: Val{N} ) where N = Broadcast. DefaultArrayStyle {N} ()
982
-
983
- Broadcast. BroadcastStyle (:: PromoteToSparse , :: SPVM ) = PromoteToSparse ()
984
- Broadcast. BroadcastStyle (:: PromoteToSparse , :: Broadcast.Style{Tuple} ) = Broadcast. DefaultArrayStyle {2} ()
985
-
986
- # FIXME : switch to DefaultArrayStyle once we can delete VectorStyle and MatrixStyle
987
- # Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{0}) = PromoteToSparse()
988
- # Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse()
989
- # Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse()
990
- BroadcastStyle (:: Type{<:Base.Adjoint{T,<:Vector}} ) where T = Broadcast. MatrixStyle () # Adjoint not yet defined when broadcast.jl loaded
991
- BroadcastStyle (:: Type{<:Base.Transpose{T,<:Vector}} ) where T = Broadcast. MatrixStyle () # Transpose not yet defined when broadcast.jl loaded
992
- Broadcast. BroadcastStyle (:: SPVM , :: Broadcast.VectorStyle ) = PromoteToSparse ()
993
- Broadcast. BroadcastStyle (:: SPVM , :: Broadcast.MatrixStyle ) = PromoteToSparse ()
994
- Broadcast. BroadcastStyle (:: SparseVecStyle , :: Broadcast.DefaultArrayStyle{N} ) where N =
995
- Broadcast. DefaultArrayStyle (Broadcast. _max (Val (N), Val (1 )))
996
- Broadcast. BroadcastStyle (:: SparseMatStyle , :: Broadcast.DefaultArrayStyle{N} ) where N =
997
- Broadcast. DefaultArrayStyle (Broadcast. _max (Val (N), Val (2 )))
998
- # end FIXME
999
-
1000
- broadcast (f, :: PromoteToSparse , :: Nothing , :: Nothing , As:: Vararg{Any,N} ) where {N} =
1001
- broadcast (f, map (_sparsifystructured, As)... )
1002
-
1003
- # For broadcast! with ::Any inputs, we need a layer of indirection to determine whether
1004
- # the inputs can be promoted to SparseVecOrMat. If it's just SparseVecOrMat and scalars,
1005
- # we can handle it here, otherwise see below for the promotion machinery.
1006
- function broadcast! (f:: Tf , dest:: SparseVecOrMat , :: SPVM , A:: SparseVecOrMat , Bs:: Vararg{SparseVecOrMat,N} ) where {Tf,N}
1007
- if f isa typeof (identity) && N == 0 && Base. axes (dest) == Base. axes (A)
1008
- return copyto! (dest, A)
1009
- end
1010
- _aresameshape (dest, A, Bs... ) && return _noshapecheck_map! (f, dest, A, Bs... )
1011
- Base. Broadcast. check_broadcast_indices (axes (dest), A, Bs... )
1012
- fofzeros = f (_zeros_eltypes (A, Bs... )... )
1013
- fpreszeros = _iszero (fofzeros)
1014
- fpreszeros ? _broadcast_zeropres! (f, dest, A, Bs... ) :
1015
- _broadcast_notzeropres! (f, fofzeros, dest, A, Bs... )
1016
- return dest
1039
+ function copy (bc:: Broadcasted{PromoteToSparse} )
1040
+ bcf = flatten (bc)
1041
+ As = Tuple (bcf. args)
1042
+ broadcast (bcf. f, map (_sparsifystructured, As)... )
1017
1043
end
1018
- function broadcast! (f:: Tf , dest:: SparseVecOrMat , :: SPVM , mixedsrcargs:: Vararg{Any,N} ) where {Tf,N}
1019
- # mixedsrcargs contains nothing but SparseVecOrMat and scalars
1020
- parevalf, passedsrcargstup = capturescalars (f, mixedsrcargs)
1021
- broadcast! (parevalf, dest, passedsrcargstup... )
1022
- return dest
1023
- end
1024
- function broadcast! (f:: Tf , dest:: SparseVecOrMat , :: PromoteToSparse , mixedsrcargs:: Vararg{Any,N} ) where {Tf,N}
1025
- broadcast! (f, dest, map (_sparsifystructured, mixedsrcargs)... )
1026
- return dest
1044
+
1045
+ function copyto! (dest:: SparseVecOrMat , bc:: Broadcasted{PromoteToSparse} )
1046
+ bcf = flatten (bc)
1047
+ As = Tuple (bcf. args)
1048
+ broadcast! (bcf. f, dest, map (_sparsifystructured, As)... )
1027
1049
end
1028
1050
1029
1051
_sparsifystructured (M:: AbstractMatrix ) = SparseMatrixCSC (M)
0 commit comments