Skip to content

Commit 7171bab

Browse files
committed
move StructStaticArray broadcast to ext
1 parent c81c308 commit 7171bab

6 files changed

+89
-84
lines changed

Project.toml

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,18 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1010

1111
[weakdeps]
1212
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
13-
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
13+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1414

1515
[extensions]
1616
StructArraysGPUArraysCoreExt = "GPUArraysCore"
17-
StructArraysStaticArraysCoreExt = "StaticArraysCore"
17+
StructArraysStaticArraysExt = "StaticArrays"
1818

1919
[compat]
2020
Adapt = "1, 2, 3"
2121
ConstructionBase = "1"
2222
DataAPI = "1"
2323
GPUArraysCore = "0.1.2"
2424
StaticArrays = "1.5.6"
25-
StaticArraysCore = "1.3"
2625
Tables = "1"
2726
julia = "1.6"
2827

@@ -34,10 +33,9 @@ OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
3433
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
3534
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3635
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
37-
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
3836
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3937
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
4038
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"
4139

4240
[targets]
43-
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays", "GPUArraysCore", "StaticArraysCore"]
41+
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays", "GPUArraysCore"]

ext/StructArraysStaticArraysCoreExt.jl

Lines changed: 0 additions & 77 deletions
This file was deleted.

ext/StructArraysStaticArraysExt.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
module StructArraysStaticArraysExt
2+
3+
using StructArrays
4+
using StaticArrays: StaticArray, FieldArray, tuple_prod
5+
6+
"""
7+
StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
8+
9+
The `staticschema` of a `StaticArray` element type is the `staticschema` of the underlying `Tuple`.
10+
```julia
11+
julia> StructArrays.staticschema(SVector{2, Float64})
12+
Tuple{Float64, Float64}
13+
```
14+
The one exception to this rule is `<:StaticArrays.FieldArray`, since `FieldArray` is based on a
15+
struct. In this case, `staticschema(<:FieldArray)` returns the `staticschema` for the struct
16+
which subtypes `FieldArray`.
17+
"""
18+
@generated function StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
19+
return quote
20+
Base.@_inline_meta
21+
return NTuple{$(tuple_prod(S)), T}
22+
end
23+
end
24+
StructArrays.createinstance(::Type{T}, args...) where {T<:StaticArray} = T(args)
25+
StructArrays.component(s::StaticArray, i) = getindex(s, i)
26+
27+
# invoke general fallbacks for a `FieldArray` type.
28+
@inline function StructArrays.staticschema(T::Type{<:FieldArray})
29+
invoke(StructArrays.staticschema, Tuple{Type{<:Any}}, T)
30+
end
31+
StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i)
32+
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(StructArrays.createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)
33+
34+
# Broadcast overload
35+
using StaticArrays: StaticArrayStyle, similar_type, Size, SOneTo
36+
using StaticArrays: broadcast_flatten, broadcast_sizes, first_statictype, __broadcast
37+
using StructArrays: isnonemptystructtype
38+
using Base.Broadcast: Broadcasted
39+
40+
# StaticArrayStyle has no similar defined.
41+
# Overload `try_struct_copy` instead.
42+
@inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M}
43+
flat = broadcast_flatten(bc); as = flat.args; f = flat.f
44+
argsizes = broadcast_sizes(as...)
45+
ax = axes(bc)
46+
ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. Please file a bug at `StaticArrays.jl`.")
47+
return _broadcast(f, Size(map(length, ax)), argsizes, as...)
48+
end
49+
50+
@inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where {newsize}
51+
first_staticarray = first_statictype(a...)
52+
elements, ET = if prod(newsize) == 0
53+
# Use inference to get eltype in empty case (see also comments in _map)
54+
eltys = Tuple{map(eltype, a)...}
55+
(), Core.Compiler.return_type(f, eltys)
56+
else
57+
temp = __broadcast(f, sz, s, a...)
58+
temp, eltype(temp)
59+
end
60+
if isnonemptystructtype(ET)
61+
@static if VERSION >= v"1.7"
62+
arrs = ntuple(Val(fieldcount(ET))) do i
63+
@inbounds similar_type(first_staticarray, fieldtype(ET, i), sz)(_getfields(elements, i))
64+
end
65+
else
66+
similarET(::Type{SA}, ::Type{T}) where {SA, T} = i -> @inbounds similar_type(SA, fieldtype(T, i), sz)(_getfields(elements, i))
67+
arrs = ntuple(similarET(first_staticarray, ET), Val(fieldcount(ET)))
68+
end
69+
return StructArray{ET}(arrs)
70+
end
71+
@inbounds return similar_type(first_staticarray, ET, sz)(elements)
72+
end
73+
74+
@inline function _getfields(x::Tuple, i::Int)
75+
if @generated
76+
return Expr(:tuple, (:(getfield(x[$j], i)) for j in 1:fieldcount(x))...)
77+
else
78+
return map(Base.Fix2(getfield, i), x)
79+
end
80+
end
81+
82+
end

src/StructArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x
3131

3232
@static if !isdefined(Base, :get_extension)
3333
include("../ext/StructArraysGPUArraysCoreExt.jl")
34-
include("../ext/StructArraysStaticArraysCoreExt.jl")
34+
include("../ext/StructArraysStaticArraysExt.jl")
3535
end
3636

3737
end # module

src/structarray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,7 @@ See also [`always_struct_broadcast`](@ref).
551551
"""
552552
try_struct_copy(bc::Broadcasted) = copy(bc)
553553

554-
function Base.copy(bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
554+
@inline function Base.copy(bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
555555
if always_struct_broadcast(S())
556556
return invoke(copy, Tuple{Broadcasted}, bc)
557557
else

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1297,8 +1297,10 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
12971297

12981298
@testset "allocation test" begin
12991299
a = StructArray{ComplexF64}(undef, 1)
1300+
sa = StructArray{ComplexF64}((SizedVector{1}(a.re), SizedVector{1}(a.re)))
13001301
allocated(a) = @allocated a .+ 1
13011302
@test allocated(a) == 2allocated(a.re)
1303+
@test allocated(sa) == 2allocated(sa.re)
13021304
allocated2(a) = @allocated a .= complex.(a.im, a.re)
13031305
@test allocated2(a) == 0
13041306
end

0 commit comments

Comments
 (0)