Skip to content

Commit 569c70e

Browse files
N5N3piever
andauthored
Turn off struct broadcast by default. (#260)
* Turn off struct broadcast by default. * add doc and rename * Add more unalias test. * Also overload `instantiate` and `_axes` This makes sure `Tuple`'s axes static during struct static broadcast. * Update wording --------- Co-authored-by: Pietro Vertechi <[email protected]>
1 parent 4ea69d5 commit 569c70e

File tree

4 files changed

+115
-31
lines changed

4 files changed

+115
-31
lines changed

src/StructArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,6 @@ function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
3838
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
3939
return backend
4040
end
41+
always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true
4142

4243
end # module

src/staticarrays_support.jl

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -38,35 +38,10 @@ end
3838
# This looks costly, but the compiler should be able to optimize them away
3939
Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) = axes(replace_structarray(bc))
4040

41-
to_staticstyle(@nospecialize(x::Type)) = x
42-
to_staticstyle(::Type{StructStaticArrayStyle{N}}) where {N} = StaticArrayStyle{N}
43-
44-
"""
45-
replace_structarray(bc::Broadcasted)
46-
47-
An internal function transforms the `Broadcasted` with `StructArray` into
48-
an equivalent one without it. This is not a must if the root `BroadcastStyle`
49-
supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,
50-
e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`.
51-
"""
52-
function replace_structarray(bc::Broadcasted{Style}) where {Style}
53-
args = replace_structarray_args(bc.args)
54-
return Broadcasted{to_staticstyle(Style)}(bc.f, args, nothing)
55-
end
56-
function replace_structarray(A::StructArray)
57-
f = Instantiator(eltype(A))
58-
args = Tuple(components(A))
59-
return Broadcasted{StaticArrayStyle{ndims(A)}}(f, args, nothing)
60-
end
61-
replace_structarray(@nospecialize(A)) = A
62-
63-
replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(tail(args))...)
64-
replace_structarray_args(::Tuple{}) = ()
65-
6641
# StaticArrayStyle has no similar defined.
6742
# Overload `Base.copy` instead.
68-
@inline function Base.copy(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M}
69-
sa = copy(convert(Broadcasted{StaticArrayStyle{M}}, bc))
43+
@inline function try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M}
44+
sa = copy(bc)
7045
ET = eltype(sa)
7146
isnonemptystructtype(ET) || return sa
7247
elements = Tuple(sa)

src/structarray.jl

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,8 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
494494
end
495495

496496
# broadcast
497-
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown
497+
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown, ArrayConflict
498+
using Base.Broadcast: combine_styles
498499

499500
struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end
500501

@@ -524,6 +525,82 @@ Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).para
524525

525526
BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{typeof(cst(SA)), ndims(SA)}()
526527

528+
"""
529+
always_struct_broadcast(style::BroadcastStyle)
530+
531+
Check if `style` supports struct-broadcast natively, which means:
532+
1) `Base.copy` is not overloaded.
533+
2) `Base.similar` is defined.
534+
3) `Base.copyto!` supports `StructArray`s as broadcasted arguments.
535+
536+
If any of the above conditions are not met, then this function should
537+
not be overloaded.
538+
In that case, try to overload [`try_struct_copy`](@ref) to support out-of-place
539+
struct-broadcast.
540+
"""
541+
always_struct_broadcast(::Any) = false
542+
always_struct_broadcast(::DefaultArrayStyle) = true
543+
always_struct_broadcast(::ArrayConflict) = true
544+
545+
"""
546+
try_struct_copy(bc::Broadcasted)
547+
548+
Entry for non-native outplace struct-broadcast.
549+
550+
See also [`always_struct_broadcast`](@ref).
551+
"""
552+
try_struct_copy(bc::Broadcasted) = copy(bc)
553+
554+
function Base.copy(bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
555+
if always_struct_broadcast(S())
556+
return invoke(copy, Tuple{Broadcasted}, bc)
557+
else
558+
return try_struct_copy(replace_structarray(bc))
559+
end
560+
end
561+
562+
"""
563+
replace_structarray(bc::Broadcasted)
564+
565+
An internal function transforms the `Broadcasted` with `StructArray` into
566+
an equivalent one without it. This is not a must if the root `BroadcastStyle`
567+
supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,
568+
e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`.
569+
"""
570+
function replace_structarray(bc::Broadcasted{Style}) where {Style}
571+
args = replace_structarray_args(bc.args)
572+
Style′ = parent_style(Style())
573+
return Broadcasted{Style′}(bc.f, args, bc.axes)
574+
end
575+
function replace_structarray(A::StructArray)
576+
f = Instantiator(eltype(A))
577+
args = Tuple(components(A))
578+
Style = typeof(combine_styles(args...))
579+
return Broadcasted{Style}(f, args, axes(A))
580+
end
581+
replace_structarray(@nospecialize(A)) = A
582+
583+
replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(tail(args))...)
584+
replace_structarray_args(::Tuple{}) = ()
585+
586+
parent_style(@nospecialize(x)) = typeof(x)
587+
parent_style(::StructArrayStyle{S, N}) where {S, N} = S
588+
parent_style(::StructArrayStyle{S, N}) where {N, S<:AbstractArrayStyle{N}} = S
589+
parent_style(::StructArrayStyle{S, N}) where {S<:AbstractArrayStyle{Any}, N} = S
590+
parent_style(::StructArrayStyle{S, N}) where {S<:AbstractArrayStyle, N} = typeof(S(Val(N)))
591+
592+
# `instantiate` and `_axes` might be overloaded for static axes.
593+
function Broadcast.instantiate(bc::Broadcasted{Style}) where {Style <: StructArrayStyle}
594+
Style′ = parent_style(Style())
595+
bc′ = Broadcast.instantiate(convert(Broadcasted{Style′}, bc))
596+
return convert(Broadcasted{Style}, bc′)
597+
end
598+
599+
function Broadcast._axes(bc::Broadcasted{Style}, ::Nothing) where {Style <: StructArrayStyle}
600+
Style′ = parent_style(Style())
601+
return Broadcast._axes(convert(Broadcasted{Style′}, bc), nothing)
602+
end
603+
527604
# Here we use `similar` defined for `S` to build the dest Array.
528605
function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S, N, ElType}
529606
bc′ = convert(Broadcasted{S}, bc)
@@ -532,12 +609,22 @@ end
532609

533610
# Unwrapper to recover the behaviour defined by parent style.
534611
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
535-
return copyto!(dest, convert(Broadcasted{S}, bc))
612+
bc′ = always_struct_broadcast(S()) ? convert(Broadcasted{S}, bc) : replace_structarray(bc)
613+
return copyto!(dest, bc′)
536614
end
537615

538616
@inline function Broadcast.materialize!(::StructArrayStyle{S}, dest, bc::Broadcasted) where {S}
539-
return Broadcast.materialize!(S(), dest, bc)
617+
bc′ = always_struct_broadcast(S()) ? bc : replace_structarray(bc)
618+
return Broadcast.materialize!(S(), dest, bc′)
540619
end
541620

542621
# for aliasing analysis during broadcast
622+
function Broadcast.broadcast_unalias(dest::StructArray, src::AbstractArray)
623+
if dest === src || any(Base.Fix2(===, src), components(dest))
624+
return src
625+
else
626+
return Base.unalias(dest, src)
627+
end
628+
end
629+
543630
Base.dataids(u::StructArray) = mapreduce(Base.dataids, (a, b) -> (a..., b...), values(components(u)), init=())

test/runtests.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,7 @@ for S in (1, 2, 3)
11831183
Base.setindex!(A::$MyArray, val, i::Int) = A.A[i] = val
11841184
Base.size(A::$MyArray) = Base.size(A.A)
11851185
Base.BroadcastStyle(::Type{<:$MyArray}) = Broadcast.ArrayStyle{$MyArray}()
1186+
StructArrays.always_struct_broadcast(::Broadcast.ArrayStyle{$MyArray}) = true
11861187
end
11871188
end
11881189
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{ElType}) where ElType =
@@ -1247,6 +1248,16 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
12471248
@test Base.broadcasted(+, reshape(1:2,1,1,2), s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}}
12481249
@test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}
12491250

1251+
#parent_style
1252+
@test StructArrays.parent_style(StructArrayStyle{Broadcast.DefaultArrayStyle{0},2}()) == Broadcast.DefaultArrayStyle{2}
1253+
@test StructArrays.parent_style(StructArrayStyle{Broadcast.Style{Tuple},2}()) == Broadcast.Style{Tuple}
1254+
1255+
# allocation test for overloaded `broadcast_unalias`
1256+
StructArrays.always_struct_broadcast(::Broadcast.ArrayStyle{MyArray1}) = false
1257+
f(s) = s .+= 1
1258+
f(s)
1259+
@test (@allocated f(s)) == 0
1260+
12501261
# issue #185
12511262
A = StructArray(randn(ComplexF64, 3, 3))
12521263
B = randn(ComplexF64, 3, 3)
@@ -1288,6 +1299,8 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
12881299
a = StructArray{ComplexF64}(undef, 1)
12891300
allocated(a) = @allocated a .+ 1
12901301
@test allocated(a) == 2allocated(a.re)
1302+
allocated2(a) = @allocated a .= complex.(a.im, a.re)
1303+
@test allocated2(a) == 0
12911304
end
12921305

12931306
@testset "StructStaticArray" begin
@@ -1299,7 +1312,7 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
12991312
@test (@inferred bclog(s)) isa typeof(s)
13001313
test_allocated(bclog, s)
13011314
@test abs.(s) .+ ((1,) .+ (1,2,3,4,5,6,7,8,9,10)) isa SMatrix
1302-
bc = Base.broadcasted(+, s, s);
1315+
bc = Base.broadcasted(+, s, s, ntuple(identity, 10));
13031316
bc = Base.broadcasted(+, bc, bc, s);
13041317
@test @inferred(Broadcast.axes(bc)) === axes(s)
13051318
end
@@ -1317,6 +1330,14 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
13171330
@test backend(bcmul2(sa)) === backend(sa)
13181331
@test (sa .+= 1) === sa
13191332
end
1333+
1334+
@testset "StructSparseArray" begin
1335+
a = sprand(10, 10, 0.5)
1336+
b = sprand(10, 10, 0.5)
1337+
c = StructArray{ComplexF64}((a, b))
1338+
d = identity.(c)
1339+
@test d isa SparseMatrixCSC
1340+
end
13201341
end
13211342

13221343
@testset "map" begin

0 commit comments

Comments
 (0)