Skip to content

Commit 2bd33e3

Browse files
pieverN5N3
andauthored
Merge 211 and prepare release (#212)
* Make `StructArrayStyle` track inputs dimension fix #185 * Add test for unstable broadcast * fix test * style tweaks * bump version * style Co-authored-by: N5N3 <[email protected]>
1 parent 8e67e4e commit 2bd33e3

File tree

3 files changed

+31
-8
lines changed

3 files changed

+31
-8
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StructArrays"
22
uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
3-
version = "0.6.4"
3+
version = "0.6.5"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/structarray.jl

+12-6
Original file line numberDiff line numberDiff line change
@@ -445,19 +445,25 @@ end
445445
# broadcast
446446
import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle
447447

448-
struct StructArrayStyle{Style} <: AbstractArrayStyle{Any} end
448+
struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end
449449

450-
@inline combine_style_types(::Type{A}, args...) where A<:AbstractArray =
450+
# Here we define the dimension tracking behavior of StructArrayStyle
451+
function StructArrayStyle{S, M}(::Val{N}) where {S, M, N}
452+
T = S <: AbstractArrayStyle{M} ? typeof(S(Val(N))) : S
453+
return StructArrayStyle{T, N}()
454+
end
455+
456+
@inline combine_style_types(::Type{A}, args...) where {A<:AbstractArray} =
451457
combine_style_types(BroadcastStyle(A), args...)
452-
@inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where A<:AbstractArray =
458+
@inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where {A<:AbstractArray} =
453459
combine_style_types(Broadcast.result_style(s, BroadcastStyle(A)), args...)
454460
combine_style_types(s::BroadcastStyle) = s
455461

456-
Base.@pure cst(::Type{SA}) where SA = combine_style_types(array_types(SA).parameters...)
462+
Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).parameters...)
457463

458-
BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(SA))}()
464+
BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{typeof(cst(SA)), ndims(SA)}()
459465

460-
Base.similar(bc::Broadcasted{StructArrayStyle{S}}, ::Type{ElType}) where {S<:DefaultArrayStyle,N,ElType} =
466+
Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S<:DefaultArrayStyle, N, ElType} =
461467
isstructtype(ElType) ? similar(StructArray{ElType}, axes(bc)) : similar(Array{ElType}, axes(bc))
462468

463469
# for aliasing analysis during broadcast

test/runtests.jl

+18-1
Original file line numberDiff line numberDiff line change
@@ -926,8 +926,25 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
926926
# used inside of broadcast but we also test it here explicitly
927927
@test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N})
928928

929-
s = StructArray{ComplexF64}((MyArray(rand(2,2)), MyArray(rand(2,2))))
929+
s = StructArray{ComplexF64}((MyArray(rand(2)), MyArray(rand(2))))
930930
@test_throws MethodError s .+ s
931+
932+
# test for dimensionality track
933+
@test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
934+
@test Base.broadcasted(+, s, 1:2) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
935+
@test Base.broadcasted(+, s, reshape(1:2,1,2)) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}}
936+
@test Base.broadcasted(+, reshape(1:2,1,1,2), s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}}
937+
938+
a = StructArray([1;2+im])
939+
b = StructArray([1;;2+im])
940+
@test a .+ b == a .+ collect(b) == collect(a) .+ b == collect(a) .+ collect(b)
941+
@test a .+ Any[1] isa StructArray
942+
943+
# issue #185
944+
A = StructArray(randn(ComplexF64, 3, 3))
945+
B = randn(ComplexF64, 3, 3)
946+
c = StructArray(randn(ComplexF64, 3))
947+
@test (A .= B .* c) === A
931948
end
932949

933950
@testset "staticarrays" begin

0 commit comments

Comments
 (0)