Skip to content

Commit f8c988e

Browse files
committed
Use StaticArrayStyle instead.
1 parent a92bc41 commit f8c988e

File tree

2 files changed

+38
-13
lines changed

2 files changed

+38
-13
lines changed

src/staticarrays_support.jl

+21-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
import StaticArrays: StaticArray, FieldArray, tuple_prod, StaticArrayStyle
1+
using StaticArrays: StaticArrays, StaticArray, FieldArray, tuple_prod, StaticArrayStyle
2+
import StaticArrays: Size
3+
import Base.Broadcast: instantiate
24

35
"""
46
StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
@@ -28,10 +30,24 @@ end
2830
StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i)
2931
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)
3032

33+
@static if isdefined(StaticArrays, :static_combine_axes)
3134
# StaticArrayStyle has no similar defined.
32-
# Convert to `DefaultArrayStyle` to return a sized (Struct)Array.
33-
# TODO: return a StaticArray?
34-
function Base.copy(bc::Broadcasted{StructArrayStyle{StaticArrayStyle{N},N}}) where {N}
35-
bc′ = convert(Broadcasted{StructArrayStyle{Broadcast.DefaultArrayStyle{N},N}}, bc)
35+
# Convert to `StaticArrayStyle` to return a StaticArray instead.
36+
StructStaticArrayStyle{N} = StructArrayStyle{StaticArrayStyle{N}, N}
37+
@inline function Base.copy(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M}
38+
bc′ = convert(Broadcasted{StaticArrayStyle{M}}, bc)
3639
return copy(bc′)
3740
end
41+
function instantiate(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M}
42+
bc′ = instantiate(convert(Broadcasted{StaticArrayStyle{M}}, bc))
43+
return convert(Broadcasted{StructStaticArrayStyle{M}}, bc′)
44+
end
45+
function Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing)
46+
return StaticArrays.static_combine_axes(bc.args...)
47+
end
48+
Size(::Type{SA}) where {SA<:StructArray} = Size(fieldtype(fieldtype(SA, 1), 1))
49+
StaticArrays.isstatic(::SA) where {SA<:StructArray} = cst(SA) isa StaticArrayStyle
50+
function StaticArrays.similar_type(::Type{SA}, ::Type{T}, s::Size{S}) where {SA<:StructArray, T, S}
51+
return StaticArrays.similar_type(fieldtype(fieldtype(SA, 1), 1), T, s)
52+
end
53+
end

test/runtests.jl

+17-8
Original file line numberDiff line numberDiff line change
@@ -977,26 +977,35 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
977977
@test @inferred(broadcast(el -> el.a, v)) == ["s1", "s2"]
978978

979979
# ambiguity check (can we do this better?)
980-
function _test(a, b, c)
980+
function _test(a, b, c, T = StructArray)
981981
if a isa StructArray || b isa StructArray || c isa StructArray
982982
d = @inferred a .+ b .- c
983983
@test d == collect(a) .+ collect(b) .- collect(c)
984-
@test d isa StructArray
984+
@test d isa T
985985
end
986986
end
987-
testset = (StructArray([1;2+im]),
987+
testset = Any[StructArray([1;2+im]),
988988
StructArray([1 2+im]),
989989
1:2,
990990
(1,2),
991-
(@SArray [1 2]),
992-
StructArray(@SArray [1 1+2im]))
991+
(@SArray [1 2])]
993992
for aa in testset, bb in testset, cc in testset
994993
_test(aa, bb, cc)
995994
end
995+
if isdefined(StaticArrays, :static_combine_axes)
996+
testset = Any[StructArray(@SArray [1 1+2im]), (1,2), StructArray(@SArray [1;1+2im])]
997+
for aa in testset, bb in testset, cc in testset
998+
_test(aa, bb, cc, StaticArray)
999+
end
1000+
end
1001+
end
9961002

997-
a = @SArray randn(3,3);
998-
b = StructArray{ComplexF64}((a,a))
999-
@test a[:,1] .+ b isa StructArray && (a[:,1] .+ b).re isa SizedMatrix
1003+
function struct_static_allocated_test()
1004+
s = StructArray{ComplexF64}((SVector(1., 2., 3.), SVector(0., 0., 0.)))
1005+
return broadcast(log, s)
1006+
end
1007+
if isdefined(StaticArrays, :static_combine_axes)
1008+
@test (@allocated struct_static_allocated_test()) === 0
10001009
end
10011010

10021011
@testset "staticarrays" begin

0 commit comments

Comments
 (0)