Skip to content

Commit f875cf9

Browse files
committed
Restrict sparse broadcast promotion to Array
This should be reverted someday
1 parent 8bb6764 commit f875cf9

File tree

2 files changed

+50
-4
lines changed

2 files changed

+50
-4
lines changed

base/broadcast.jl

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,32 @@ BroadcastStyle(a::AbstractArrayStyle{N}, ::DefaultArrayStyle{N}) where N = a
162162
BroadcastStyle(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
163163
typeof(a)(_max(Val(M),Val(N)))
164164

165+
# FIXME
166+
# The following definitions are necessary to limit SparseArray broadcasting to "plain Arrays"
167+
# (see https://github.com/JuliaLang/julia/pull/23939#pullrequestreview-72075382).
168+
# They should be deleted once the sparse broadcast infrastucture is capable of handling
169+
# arbitrary AbstractArrays.
170+
struct VectorStyle <: AbstractArrayStyle{1} end
171+
struct MatrixStyle <: AbstractArrayStyle{2} end
172+
const VMStyle = Union{VectorStyle,MatrixStyle}
173+
# These lose to DefaultArrayStyle
174+
VectorStyle(::Val{N}) where N = DefaultArrayStyle{N}()
175+
MatrixStyle(::Val{N}) where N = DefaultArrayStyle{N}()
176+
177+
BroadcastStyle(::Type{<:Vector}) = VectorStyle()
178+
BroadcastStyle(::Type{<:Matrix}) = MatrixStyle()
179+
180+
BroadcastStyle(::MatrixStyle, ::VectorStyle) = MatrixStyle()
181+
BroadcastStyle(a::AbstractArrayStyle{Any}, ::VectorStyle) = a
182+
BroadcastStyle(a::AbstractArrayStyle{Any}, ::MatrixStyle) = a
183+
BroadcastStyle(a::AbstractArrayStyle{N}, ::VectorStyle) where N = typeof(a)(_max(Val(N), Val(1)))
184+
BroadcastStyle(a::AbstractArrayStyle{N}, ::MatrixStyle) where N = typeof(a)(_max(Val(N), Val(2)))
185+
BroadcastStyle(::VectorStyle, ::DefaultArrayStyle{N}) where N = DefaultArrayStyle(_max(Val(N), Val(1)))
186+
BroadcastStyle(::MatrixStyle, ::DefaultArrayStyle{N}) where N = DefaultArrayStyle(_max(Val(N), Val(2)))
187+
# to avoid the VectorStyle(::Val) constructor we also need the following
188+
BroadcastStyle(::VectorStyle, ::MatrixStyle) = MatrixStyle()
189+
# end FIXME
190+
165191
## Allocating the output container
166192
"""
167193
broadcast_similar(f, ::BroadcastStyle, ::Type{ElType}, inds, As...)
@@ -181,6 +207,17 @@ broadcast_similar(f, ::ArrayConflict, ::Type{ElType}, inds::Indices, As...) wher
181207
broadcast_similar(f, ::ArrayConflict, ::Type{Bool}, inds::Indices, As...) =
182208
similar(BitArray, inds)
183209

210+
# FIXME: delete when we get rid of VectorStyle and MatrixStyle
211+
broadcast_similar(f, ::VectorStyle, ::Type{ElType}, inds::Indices{1}, As...) where ElType =
212+
similar(Vector{ElType}, inds)
213+
broadcast_similar(f, ::MatrixStyle, ::Type{ElType}, inds::Indices{2}, As...) where ElType =
214+
similar(Matrix{ElType}, inds)
215+
broadcast_similar(f, ::VectorStyle, ::Type{Bool}, inds::Indices{1}, As...) =
216+
similar(BitArray, inds)
217+
broadcast_similar(f, ::MatrixStyle, ::Type{Bool}, inds::Indices{2}, As...) =
218+
similar(BitArray, inds)
219+
# end FIXME
220+
184221
## Computing the result's indices. Most types probably won't need to specialize this.
185222
broadcast_indices() = ()
186223
broadcast_indices(::Type{T}) where T = ()
@@ -582,7 +619,7 @@ Nullable{Complex{Float64}}()
582619
broadcast(f, s, combine_eltypes(f, A, Bs...), combine_indices(A, Bs...),
583620
A, Bs...)
584621

585-
const NonleafHandlingTypes = Union{DefaultArrayStyle,ArrayConflict}
622+
const NonleafHandlingTypes = Union{DefaultArrayStyle,ArrayConflict,VectorStyle,MatrixStyle}
586623

587624
@inline function broadcast(f, s::NonleafHandlingTypes, ::Type{ElType}, inds::Indices, As...) where ElType
588625
if !Base._isleaftype(ElType)

base/sparse/higherorderfns.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -989,9 +989,18 @@ PromoteToSparse(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
989989
Broadcast.BroadcastStyle(::PromoteToSparse, ::SPVM) = PromoteToSparse()
990990
Broadcast.BroadcastStyle(::PromoteToSparse, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}()
991991

992-
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{0}) = PromoteToSparse()
993-
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse()
994-
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse()
992+
# FIXME: switch to DefaultArrayStyle once we can delete VectorStyle and MatrixStyle
993+
# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{0}) = PromoteToSparse()
994+
# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse()
995+
# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse()
996+
BroadcastStyle(::Type{<:Base.RowVector{T,<:Vector}}) where T = Broadcast.MatrixStyle() # RowVector not yet defined when broadcast.jl loaded
997+
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.VectorStyle) = PromoteToSparse()
998+
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.MatrixStyle) = PromoteToSparse()
999+
Broadcast.BroadcastStyle(::SparseVecStyle, ::Broadcast.DefaultArrayStyle{N}) where N =
1000+
Broadcast.DefaultArrayStyle(Broadcast._max(Val(N), Val(1)))
1001+
Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.DefaultArrayStyle{N}) where N =
1002+
Broadcast.DefaultArrayStyle(Broadcast._max(Val(N), Val(2)))
1003+
# end FIXME
9951004

9961005
broadcast(f, ::PromoteToSparse, ::Void, ::Void, As::Vararg{Any,N}) where {N} =
9971006
broadcast(f, map(_sparsifystructured, As)...)

0 commit comments

Comments
 (0)