Skip to content

Commit 432dcb9

Browse files
committed
Reviewed
1. Make `broadcast` extendable 2. Fix empty case.
1 parent df49828 commit 432dcb9

File tree

3 files changed

+46
-11
lines changed

3 files changed

+46
-11
lines changed

src/broadcast.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,19 +97,28 @@ end
9797
scalar_getindex(x) = x
9898
scalar_getindex(x::Ref) = x[]
9999

100-
@generated function _broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
101-
first_staticarray = a[findfirst(ai -> ai <: Union{StaticArray, Transpose{<:Any, <:StaticArray}, Adjoint{<:Any, <:StaticArray}, Diagonal{<:Any, <:StaticArray}}, a)]
100+
isstatic(::StaticArray) = true
101+
isstatic(::Transpose{<:Any, <:StaticArray}) = true
102+
isstatic(::Adjoint{<:Any, <:StaticArray}) = true
103+
isstatic(::Diagonal{<:Any, <:StaticArray}) = true
104+
isstatic(_) = false
102105

106+
@inline first_statictype(x, y...) = isstatic(x) ? typeof(x) : first_statictype(y...)
107+
first_statictype() = error("unresolved dest type")
108+
109+
@inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
110+
first_staticarray = first_statictype(a...)
103111
if prod(newsize) == 0
104112
# Use inference to get eltype in empty case (see also comments in _map)
105-
eltys = [:(eltype(a[$i])) for i 1:length(a)]
106-
return quote
107-
@_inline_meta
108-
T = Core.Compiler.return_type(f, Tuple{$(eltys...)})
109-
@inbounds return similar_type($first_staticarray, T, Size(newsize))()
110-
end
113+
eltys = Tuple{map(eltype, a)...}
114+
T = Core.Compiler.return_type(f, eltys)
115+
@inbounds return similar_type(first_staticarray, T, Size(newsize))()
111116
end
117+
elements = __broadcast(f, sz, s, a...)
118+
@inbounds return similar_type(first_staticarray, eltype(elements), Size(newsize))(elements)
119+
end
112120

121+
@generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
113122
sizes = [sz.parameters[1] for sz s.parameters]
114123
indices = CartesianIndices(newsize)
115124
exprs = similar(indices, Expr)
@@ -123,8 +132,7 @@ scalar_getindex(x::Ref) = x[]
123132

124133
return quote
125134
@_inline_meta
126-
@inbounds elements = tuple($(exprs...))
127-
@inbounds return similar_type($first_staticarray, eltype(elements), Size(newsize))(elements)
135+
@inbounds return elements = tuple($(exprs...))
128136
end
129137
end
130138

src/precompile.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function _precompile_()
2222
end
2323

2424
# Some expensive generators
25-
@assert precompile(Tuple{typeof(which(_broadcast,(Any,Size,Tuple{Vararg{Size}},Vararg{Any},)).generator.gen),Any,Any,Any,Any,Any,Any})
25+
@assert precompile(Tuple{typeof(which(__broadcast,(Any,Size,Tuple{Vararg{Size}},Vararg{Any},)).generator.gen),Any,Any,Any,Any,Any,Any})
2626
@assert precompile(Tuple{typeof(which(_zeros,(Size,Type{<:StaticArray},)).generator.gen),Any,Any,Any,Type,Any})
2727
@assert precompile(Tuple{typeof(which(combine_sizes,(Tuple{Vararg{Size}},)).generator.gen),Any,Any})
2828
@assert precompile(Tuple{typeof(which(_mapfoldl,(Any,Any,Colon,Any,Size,Vararg{StaticArray},)).generator.gen),Any,Any,Any,Any,Any,Any,Any,Any})

test/broadcast.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,3 +282,30 @@ end
282282
end
283283

284284
end
285+
286+
# A help struct to test style-based broadcast dispatch with unknown array wrapper.
287+
# `WrapArray(A)` behaves like `A` during broadcast. But its not a `StaticArray`.
288+
struct WrapArray{T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N}
289+
data::P
290+
end
291+
Base.@propagate_inbounds Base.getindex(A::WrapArray, i::Integer...) = A.data[i...]
292+
Base.@propagate_inbounds Base.setindex!(A::WrapArray, v::Any, i::Integer...) = setindex!(A.data, v, i...)
293+
Base.size(A::WrapArray) = size(A.data)
294+
Broadcast.BroadcastStyle(::Type{WrapArray{T,N,P}}) where {T,N,P} = Broadcast.BroadcastStyle(P)
295+
StaticArrays.isstatic(A::WrapArray) = StaticArrays.isstatic(A.data)
296+
StaticArrays.Size(::Type{WrapArray{T,N,P}}) where {T,N,P} = StaticArrays.Size(P)
297+
function StaticArrays.similar_type(::Type{WrapArray{T,N,P}}, ::Type{t}, s::Size{S}) where {T,N,P,t,S}
298+
return StaticArrays.similar_type(P, t, s)
299+
end
300+
301+
@testset "Broadcast with unknown wrapper" begin
302+
data = (1, 2)
303+
for T in (SVector{2}, MVector{2})
304+
a = T(data)
305+
b = WrapArray(a)
306+
@test @inferred(b .+ a) isa T
307+
@test @inferred(b .+ b) isa T
308+
@test @inferred(b .+ (1, 2)) isa T
309+
@test b .+ a == b .+ b == a .+ a
310+
end
311+
end

0 commit comments

Comments
 (0)