Skip to content

Commit d2dde28

Browse files
N5N3piever
andcommitted
Adopt suggestions and add more internal doc/ comments.
Co-Authored-By: Pietro Vertechi <[email protected]>
1 parent 9ba2769 commit d2dde28

File tree

4 files changed

+36
-10
lines changed

4 files changed

+36
-10
lines changed

src/StructArrays.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ import Adapt
3030
Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s)
3131

3232
# for GPU broadcast
33-
import GPUArraysCore: backend
34-
function backend(::Type{T}) where {T<:StructArray}
35-
backs = map(backend, fieldtypes(array_types(T)))
33+
import GPUArraysCore
34+
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
35+
backs = map(GPUArraysCore.backend, fieldtypes(array_types(T)))
3636
all(Base.Fix2(===, backs[1]), tail(backs)) || error("backend mismatch!")
3737
return backs[1]
3838
end

src/interface.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,8 @@ end
5151

5252
createinstance(::Type{T}, args...) where {T<:Tup} = T(args)
5353

54-
createinstance(::Type{T}) where {T} = (x...) -> createinstance(T, x...)
54+
struct Instantiator{T} end
55+
56+
Instantiator(::Type{T}) where {T} = Instantiator{T}()
57+
58+
(::Instantiator{T})(args...) where {T} = createinstance(T, args...)

src/staticarrays_support.jl

+12-3
Original file line numberDiff line numberDiff line change
@@ -35,23 +35,32 @@ function Broadcast.instantiate(bc::Broadcasted{StructStaticArrayStyle{M}}) where
3535
bc′ = Broadcast.instantiate(replace_structarray(bc))
3636
return convert(Broadcasted{StructStaticArrayStyle{M}}, bc′)
3737
end
38-
# This looks costy, but compiler should be able to optimize them away
38+
# This looks costly, but the compiler should be able to optimize them away
3939
Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) = axes(replace_structarray(bc))
4040

4141
to_staticstyle(@nospecialize(x::Type)) = x
4242
to_staticstyle(::Type{StructStaticArrayStyle{N}}) where {N} = StaticArrayStyle{N}
43+
44+
"""
45+
replace_structarray(bc::Broadcasted)
46+
47+
An internal function transforms the `Broadcasted` with `StructArray` into
48+
an equivalent one without it. This is not a must if the root `BroadcastStyle`
49+
supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,
50+
e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`.
51+
"""
4352
function replace_structarray(bc::Broadcasted{Style}) where {Style}
4453
args = replace_structarray_args(bc.args)
4554
return Broadcasted{to_staticstyle(Style)}(bc.f, args, nothing)
4655
end
4756
function replace_structarray(A::StructArray)
48-
f = createinstance(eltype(A))
57+
f = Instantiator(eltype(A))
4958
args = Tuple(components(A))
5059
return Broadcasted{StaticArrayStyle{ndims(A)}}(f, args, nothing)
5160
end
5261
replace_structarray(@nospecialize(A)) = A
5362

54-
replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(Base.tail(args))...)
63+
replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(tail(args))...)
5564
replace_structarray_args(::Tuple{}) = ()
5665

5766
# StaticArrayStyle has no similar defined.

test/runtests.jl

+16-3
Original file line numberDiff line numberDiff line change
@@ -1091,6 +1091,19 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs)
10911091
@test t.b.d isa Array
10921092
end
10931093

1094+
# The following code defines `MyArray1/2/3` with different `BroadcastStyle`s.
1095+
# 1. `MyArray1` and `MyArray1` have `similar` defined.
1096+
# We use them to simulate `BroadcastStyle` overloading `Base.copyto!`.
1097+
# 2. `MyArray3` has no `similar` defined.
1098+
# We use it to simulate `BroadcastStyle` overloading `Base.copy`.
1099+
# 3. Their resolved style could be summaryized as (`-` means conflict)
1100+
# | MyArray1 | MyArray2 | MyArray3 | Array
1101+
# -------------------------------------------------------------
1102+
# MyArray1 | MyArray1 | - | MyArray1 | MyArray1
1103+
# MyArray2 | - | MyArray2 | - | MyArray2
1104+
# MyArray3 | MyArray1 | - | MyArray3 | MyArray3
1105+
# Array | MyArray1 | Array | MyArray3 | Array
1106+
10941107
for S in (1, 2, 3)
10951108
MyArray = Symbol(:MyArray, S)
10961109
@eval begin
@@ -1129,9 +1142,9 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
11291142
@test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N})
11301143

11311144
# Make sure we can handle style with similar defined
1132-
# And we can handle most conflict
1133-
# s1 and s2 has similar defined, but s3 not
1134-
# s2 are conflict with s1 and s3. (And it's weaker than DefaultArrayStyle)
1145+
# And we can handle most conflicts
1146+
# `s1` and `s2` have similar defined, but `s3` does not
1147+
# `s2` conflicts with `s1` and `s3` and is weaker than `DefaultArrayStyle`
11351148
s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2))))
11361149
s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2))))
11371150
s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2))))

0 commit comments

Comments
 (0)