Skip to content

Commit 1afecf4

Browse files
authored
Generalize StructArray's broadcast. (#215)
1 parent 4056c71 commit 1afecf4

File tree

6 files changed

+232
-26
lines changed

6 files changed

+232
-26
lines changed

Project.toml

+6-3
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,22 @@ version = "0.6.13"
55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
77
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
8+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
89
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
910
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1011

1112
[compat]
1213
Adapt = "1, 2, 3"
1314
DataAPI = "1"
14-
StaticArraysCore = "1.1"
15-
StaticArrays = "1.5.4"
15+
GPUArraysCore = "0.1.2"
16+
StaticArrays = "1.5.6"
17+
StaticArraysCore = "1.3"
1618
Tables = "1"
1719
julia = "1.6"
1820

1921
[extras]
2022
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
23+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
2124
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
2225
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
2326
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -26,4 +29,4 @@ TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
2629
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"
2730

2831
[targets]
29-
test = ["Test", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"]
32+
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"]

src/StructArrays.jl

+10
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,14 @@ end
2929
import Adapt
3030
Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s)
3131

32+
# for GPU broadcast
33+
import GPUArraysCore
34+
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
35+
backends = map_params(GPUArraysCore.backend, array_types(T))
36+
backend, others = backends[1], tail(backends)
37+
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
38+
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
39+
return backend
40+
end
41+
3242
end # module

src/interface.jl

+7-1
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,10 @@ function createinstance(::Type{T}, args...) where {T}
4949
isconcretetype(T) ? bypass_constructor(T, args) : T(args...)
5050
end
5151

52-
createinstance(::Type{T}, args...) where {T<:Tup} = T(args)
52+
createinstance(::Type{T}, args...) where {T<:Tup} = T(args)
53+
54+
struct Instantiator{T} end
55+
56+
Instantiator(::Type{T}) where {T} = Instantiator{T}()
57+
58+
(::Instantiator{T})(args...) where {T} = createinstance(T, args...)

src/staticarrays_support.jl

+65-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import StaticArraysCore: StaticArray, FieldArray, tuple_prod
1+
using StaticArraysCore: StaticArray, FieldArray, tuple_prod
22

33
"""
44
StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
@@ -27,3 +27,67 @@ StructArrays.component(s::StaticArray, i) = getindex(s, i)
2727
end
2828
StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i)
2929
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)
30+
31+
# Broadcast overload
32+
using StaticArraysCore: StaticArrayStyle, similar_type
33+
StructStaticArrayStyle{N} = StructArrayStyle{StaticArrayStyle{N}, N}
34+
function Broadcast.instantiate(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M}
35+
bc′ = Broadcast.instantiate(replace_structarray(bc))
36+
return convert(Broadcasted{StructStaticArrayStyle{M}}, bc′)
37+
end
38+
# This looks costly, but the compiler should be able to optimize them away
39+
Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) = axes(replace_structarray(bc))
40+
41+
to_staticstyle(@nospecialize(x::Type)) = x
42+
to_staticstyle(::Type{StructStaticArrayStyle{N}}) where {N} = StaticArrayStyle{N}
43+
44+
"""
45+
replace_structarray(bc::Broadcasted)
46+
47+
An internal function transforms the `Broadcasted` with `StructArray` into
48+
an equivalent one without it. This is not a must if the root `BroadcastStyle`
49+
supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,
50+
e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`.
51+
"""
52+
function replace_structarray(bc::Broadcasted{Style}) where {Style}
53+
args = replace_structarray_args(bc.args)
54+
return Broadcasted{to_staticstyle(Style)}(bc.f, args, nothing)
55+
end
56+
function replace_structarray(A::StructArray)
57+
f = Instantiator(eltype(A))
58+
args = Tuple(components(A))
59+
return Broadcasted{StaticArrayStyle{ndims(A)}}(f, args, nothing)
60+
end
61+
replace_structarray(@nospecialize(A)) = A
62+
63+
replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(tail(args))...)
64+
replace_structarray_args(::Tuple{}) = ()
65+
66+
# StaticArrayStyle has no similar defined.
67+
# Overload `Base.copy` instead.
68+
@inline function Base.copy(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M}
69+
sa = copy(convert(Broadcasted{StaticArrayStyle{M}}, bc))
70+
ET = eltype(sa)
71+
isnonemptystructtype(ET) || return sa
72+
elements = Tuple(sa)
73+
@static if VERSION >= v"1.7"
74+
arrs = ntuple(Val(fieldcount(ET))) do i
75+
similar_type(sa, fieldtype(ET, i))(_getfields(elements, i))
76+
end
77+
else
78+
_fieldtype(::Type{T}) where {T} = i -> fieldtype(T, i)
79+
__fieldtype = _fieldtype(ET)
80+
arrs = ntuple(Val(fieldcount(ET))) do i
81+
similar_type(sa, __fieldtype(i))(_getfields(elements, i))
82+
end
83+
end
84+
return StructArray{ET}(arrs)
85+
end
86+
87+
@inline function _getfields(x::Tuple, i::Int)
88+
if @generated
89+
return Expr(:tuple, (:(getfield(x[$j], i)) for j in 1:fieldcount(x))...)
90+
else
91+
return map(Base.Fix2(getfield, i), x)
92+
end
93+
end

src/structarray.jl

+24-4
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
486486
end
487487

488488
# broadcast
489-
import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle
489+
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown
490490

491491
struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end
492492

@@ -496,19 +496,39 @@ function StructArrayStyle{S, M}(::Val{N}) where {S, M, N}
496496
return StructArrayStyle{T, N}()
497497
end
498498

499+
# StructArrayStyle is a wrapped style.
500+
# Here we try our best to resolve style conflict.
501+
function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S, N}) where {S, N, M}
502+
N′ = M === Any || N === Any ? Any : max(M, N)
503+
S′ = Broadcast.result_style(S(), b)
504+
return S′ isa StructArrayStyle ? typeof(S′)(Val{N′}()) : StructArrayStyle{typeof(S′), N′}()
505+
end
506+
BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown()
507+
499508
@inline combine_style_types(::Type{A}, args...) where {A<:AbstractArray} =
500509
combine_style_types(BroadcastStyle(A), args...)
501510
@inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where {A<:AbstractArray} =
502511
combine_style_types(Broadcast.result_style(s, BroadcastStyle(A)), args...)
512+
combine_style_types(::StructArrayStyle{S}) where {S} = S() # avoid nested StructArrayStyle
503513
combine_style_types(s::BroadcastStyle) = s
504514

505515
Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).parameters...)
506516

507517
BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{typeof(cst(SA)), ndims(SA)}()
508518

509-
function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S<:Union{DefaultArrayStyle,StructArrayStyle}, N, ElType}
510-
ContainerType = isnonemptystructtype(ElType) ? StructArray{ElType} : Array{ElType}
511-
return similar(ContainerType, axes(bc))
519+
# Here we use `similar` defined for `S` to build the dest Array.
520+
function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S, N, ElType}
521+
bc′ = convert(Broadcasted{S}, bc)
522+
return isnonemptystructtype(ElType) ? buildfromschema(T -> similar(bc′, T), ElType) : similar(bc′, ElType)
523+
end
524+
525+
# Unwrapper to recover the behaviour defined by parent style.
526+
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
527+
return copyto!(dest, convert(Broadcasted{S}, bc))
528+
end
529+
530+
@inline function Broadcast.materialize!(::StructArrayStyle{S}, dest, bc::Broadcasted) where {S}
531+
return Broadcast.materialize!(S(), dest, bc)
512532
end
513533

514534
# for aliasing analysis during broadcast

test/runtests.jl

+120-17
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import Tables, PooledArrays, WeakRefStrings
66
using TypedTables: Table
77
using DataAPI: refarray, refvalue
88
using Adapt: adapt, Adapt
9+
using JLArrays
910
using Test
1011

1112
using Documenter: doctest
@@ -1100,17 +1101,39 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs)
11001101
@test t.b.d isa Array
11011102
end
11021103

1103-
struct MyArray{T,N} <: AbstractArray{T,N}
1104-
A::Array{T,N}
1104+
# The following code defines `MyArray1/2/3` with different `BroadcastStyle`s.
1105+
# 1. `MyArray1` and `MyArray1` have `similar` defined.
1106+
# We use them to simulate `BroadcastStyle` overloading `Base.copyto!`.
1107+
# 2. `MyArray3` has no `similar` defined.
1108+
# We use it to simulate `BroadcastStyle` overloading `Base.copy`.
1109+
# 3. Their resolved style could be summaryized as (`-` means conflict)
1110+
# | MyArray1 | MyArray2 | MyArray3 | Array
1111+
# -------------------------------------------------------------
1112+
# MyArray1 | MyArray1 | - | MyArray1 | MyArray1
1113+
# MyArray2 | - | MyArray2 | - | MyArray2
1114+
# MyArray3 | MyArray1 | - | MyArray3 | MyArray3
1115+
# Array | MyArray1 | Array | MyArray3 | Array
1116+
1117+
for S in (1, 2, 3)
1118+
MyArray = Symbol(:MyArray, S)
1119+
@eval begin
1120+
struct $MyArray{T,N} <: AbstractArray{T,N}
1121+
A::Array{T,N}
1122+
end
1123+
$MyArray{T}(::UndefInitializer, sz::Dims) where T = $MyArray(Array{T}(undef, sz))
1124+
Base.IndexStyle(::Type{<:$MyArray}) = IndexLinear()
1125+
Base.getindex(A::$MyArray, i::Int) = A.A[i]
1126+
Base.setindex!(A::$MyArray, val, i::Int) = A.A[i] = val
1127+
Base.size(A::$MyArray) = Base.size(A.A)
1128+
Base.BroadcastStyle(::Type{<:$MyArray}) = Broadcast.ArrayStyle{$MyArray}()
1129+
end
11051130
end
1106-
MyArray{T}(::UndefInitializer, sz::Dims) where T = MyArray(Array{T}(undef, sz))
1107-
Base.IndexStyle(::Type{<:MyArray}) = IndexLinear()
1108-
Base.getindex(A::MyArray, i::Int) = A.A[i]
1109-
Base.setindex!(A::MyArray, val, i::Int) = A.A[i] = val
1110-
Base.size(A::MyArray) = Base.size(A.A)
1111-
Base.BroadcastStyle(::Type{<:MyArray}) = Broadcast.ArrayStyle{MyArray}()
1112-
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{ElType}) where ElType =
1113-
MyArray{ElType}(undef, size(bc))
1131+
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{ElType}) where ElType =
1132+
MyArray1{ElType}(undef, size(bc))
1133+
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}}, ::Type{ElType}) where ElType =
1134+
MyArray2{ElType}(undef, size(bc))
1135+
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyArray3}) = Broadcast.ArrayStyle{MyArray1}()
1136+
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayStyle) = S
11141137

11151138
@testset "broadcast" begin
11161139
s = StructArray{ComplexF64}((rand(2,2), rand(2,2)))
@@ -1128,19 +1151,44 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
11281151
# used inside of broadcast but we also test it here explicitly
11291152
@test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N})
11301153

1131-
s = StructArray{ComplexF64}((MyArray(rand(2)), MyArray(rand(2))))
1132-
@test_throws MethodError s .+ s
11331154

1155+
@testset "style conflict check" begin
1156+
using StructArrays: StructArrayStyle
1157+
# Make sure we can handle style with similar defined
1158+
# And we can handle most conflicts
1159+
# `s1` and `s2` have similar defined, but `s3` does not
1160+
# `s2` conflicts with `s1` and `s3` and is weaker than `DefaultArrayStyle`
1161+
s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2))))
1162+
s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2))))
1163+
s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2))))
1164+
s4 = StructArray{ComplexF64}((rand(2), rand(2)))
1165+
test_set = Any[s1, s2, s3, s4]
1166+
tested_style = Any[]
1167+
dotaddadd((a, b, c),) = @. a + b + c
1168+
for as in Iterators.product(test_set, test_set, test_set)
1169+
ares = map(a->a.re, as)
1170+
aims = map(a->a.im, as)
1171+
style = Broadcast.combine_styles(ares...)
1172+
@test Broadcast.combine_styles(as...) === StructArrayStyle{typeof(style),1}()
1173+
if !(style in tested_style)
1174+
push!(tested_style, style)
1175+
if style isa Broadcast.ArrayStyle{MyArray3}
1176+
@test_throws MethodError dotaddadd(as)
1177+
else
1178+
d = StructArray{ComplexF64}((dotaddadd(ares), dotaddadd(aims)))
1179+
@test @inferred(dotaddadd(as))::typeof(d) == d
1180+
end
1181+
end
1182+
end
1183+
@test length(tested_style) == 5
1184+
end
11341185
# test for dimensionality track
1186+
s = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2))))
11351187
@test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
11361188
@test Base.broadcasted(+, s, 1:2) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
11371189
@test Base.broadcasted(+, s, reshape(1:2,1,2)) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}}
11381190
@test Base.broadcasted(+, reshape(1:2,1,1,2), s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}}
1139-
1140-
a = StructArray([1;2+im])
1141-
b = StructArray([1;;2+im])
1142-
@test a .+ b == a .+ collect(b) == collect(a) .+ b == collect(a) .+ collect(b)
1143-
@test a .+ Any[1] isa StructArray
1191+
@test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}
11441192

11451193
# issue #185
11461194
A = StructArray(randn(ComplexF64, 3, 3))
@@ -1155,6 +1203,61 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
11551203

11561204
@test identity.(StructArray(x=StructArray(a=1:3)))::StructArray == [(x=(a=1,),), (x=(a=2,),), (x=(a=3,),)]
11571205
@test (x -> x.x.a).(StructArray(x=StructArray(a=1:3))) == [1, 2, 3]
1206+
@test identity.(StructArray(x=StructArray(x=StructArray(a=1:3))))::StructArray == [(x=(x=(a=1,),),), (x=(x=(a=2,),),), (x=(x=(a=3,),),)]
1207+
@test (x -> x.x.x.a).(StructArray(x=StructArray(x=StructArray(a=1:3)))) == [1, 2, 3]
1208+
1209+
@testset "ambiguity check" begin
1210+
test_set = Any[StructArray([1;2+im]),
1211+
1:2,
1212+
(1,2),
1213+
StructArray(@SArray [1;1+2im]),
1214+
(@SArray [1 2]),
1215+
1]
1216+
tested_style = StructArrayStyle[]
1217+
dotaddsub((a, b, c),) = @. a + b - c
1218+
for as in Iterators.product(test_set, test_set, test_set)
1219+
if any(a -> a isa StructArray, as)
1220+
style = Broadcast.combine_styles(as...)
1221+
if !(style in tested_style)
1222+
push!(tested_style, style)
1223+
@test @inferred(dotaddsub(as))::StructArray == dotaddsub(map(collect, as))
1224+
end
1225+
end
1226+
end
1227+
@test length(tested_style) == 4
1228+
end
1229+
1230+
@testset "allocation test" begin
1231+
a = StructArray{ComplexF64}(undef, 1)
1232+
allocated(a) = @allocated a .+ 1
1233+
@test allocated(a) == 2allocated(a.re)
1234+
end
1235+
1236+
@testset "StructStaticArray" begin
1237+
bclog(s) = log.(s)
1238+
test_allocated(f, s) = @test (@allocated f(s)) == 0
1239+
a = @SMatrix [float(i) for i in 1:10, j in 1:10]
1240+
b = @SMatrix [0. for i in 1:10, j in 1:10]
1241+
s = StructArray{ComplexF64}((a , b))
1242+
@test (@inferred bclog(s)) isa typeof(s)
1243+
test_allocated(bclog, s)
1244+
@test abs.(s) .+ ((1,) .+ (1,2,3,4,5,6,7,8,9,10)) isa SMatrix
1245+
bc = Base.broadcasted(+, s, s);
1246+
bc = Base.broadcasted(+, bc, bc, s);
1247+
@test @inferred(Broadcast.axes(bc)) === axes(s)
1248+
end
1249+
1250+
@testset "StructJLArray" begin
1251+
bcabs(a) = abs.(a)
1252+
bcmul2(a) = 2 .* a
1253+
a = StructArray(randn(ComplexF32, 10, 10))
1254+
sa = jl(a)
1255+
backend = StructArrays.GPUArraysCore.backend
1256+
@test @inferred(backend(sa)) === backend(sa.re) === backend(sa.im)
1257+
@test collect(@inferred(bcabs(sa))) == bcabs(a)
1258+
@test @inferred(bcmul2(sa)) isa StructArray
1259+
@test (sa .+= 1) isa StructArray
1260+
end
11581261
end
11591262

11601263
@testset "map" begin

0 commit comments

Comments
 (0)