@@ -29,16 +29,30 @@ StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{
29
29
StructArrays. createinstance (T:: Type{<:FieldArray} , args... ) = invoke (createinstance, Tuple{Type{<: Any }, Vararg}, T, args... )
30
30
31
31
# Broadcast overload
32
- using StaticArraysCore: StaticArrayStyle
33
- import StaticArraysCore: Size, is_staticarray_like, similar_type
32
+ using StaticArraysCore: StaticArrayStyle, similar_type
34
33
StructStaticArrayStyle{N} = StructArrayStyle{StaticArrayStyle{N}, N}
35
34
function Broadcast. instantiate (bc:: Broadcasted{StructStaticArrayStyle{M}} ) where {M}
36
- bc′ = Broadcast. instantiate (convert (Broadcasted{StaticArrayStyle{M}}, bc))
35
+ bc′ = Broadcast. instantiate (replace_structarray ( bc))
37
36
return convert (Broadcasted{StructStaticArrayStyle{M}}, bc′)
38
37
end
39
- function Broadcast. _axes (bc:: Broadcasted{StructStaticArrayStyle{M}} , :: Nothing ) where {M}
40
- return Broadcast. _axes (convert (Broadcasted{StaticArrayStyle{M}}, bc), nothing )
38
+ # This looks costy, but compiler should be able to optimize them away
39
+ Broadcast. _axes (bc:: Broadcasted{<:StructStaticArrayStyle} , :: Nothing ) = axes (replace_structarray (bc))
40
+
41
+ to_staticstyle (@nospecialize (x:: Type )) = x
42
+ to_staticstyle (:: Type{StructStaticArrayStyle{N}} ) where {N} = StaticArrayStyle{N}
43
+ function replace_structarray (bc:: Broadcasted{Style} ) where {Style}
44
+ args = replace_structarray_args (bc. args)
45
+ return Broadcasted {to_staticstyle(Style)} (bc. f, args, nothing )
46
+ end
47
+ function replace_structarray (A:: StructArray )
48
+ f = createinstance (eltype (A))
49
+ args = Tuple (components (A))
50
+ return Broadcasted {StaticArrayStyle{ndims(A)}} (f, args, nothing )
41
51
end
52
+ replace_structarray (@nospecialize (A)) = A
53
+
54
+ replace_structarray_args (args:: Tuple ) = (replace_structarray (args[1 ]), replace_structarray_args (Base. tail (args))... )
55
+ replace_structarray_args (:: Tuple{} ) = ()
42
56
43
57
# StaticArrayStyle has no similar defined.
44
58
# Overload `Base.copy` instead.
48
62
isnonemptystructtype (ET) || return sa
49
63
elements = Tuple (sa)
50
64
arrs = ntuple (Val (fieldcount (ET))) do i
51
- similar_type (sa, fieldtype (ET, i), Size (sa) )(_getfields (elements, i))
65
+ similar_type (sa, fieldtype (ET, i))(_getfields (elements, i))
52
66
end
53
67
return StructArray {ET} (arrs)
54
68
end
60
74
return map (Base. Fix2 (getfield, i), x)
61
75
end
62
76
end
63
-
64
- Size (:: Type{SA} ) where {SA<: StructArray } = Size (fieldtype (array_types (SA), 1 ))
65
- is_staticarray_like (x:: StructArray ) = any (is_staticarray_like, components (x))
66
- function similar_type (:: Type{SA} , :: Type{T} , s:: Size{S} ) where {SA<: StructArray , T, S}
67
- return similar_type (fieldtype (array_types (SA), 1 ), T, s)
68
- end
0 commit comments