@@ -6,6 +6,7 @@ import Tables, PooledArrays, WeakRefStrings
6
6
using TypedTables: Table
7
7
using DataAPI: refarray, refvalue
8
8
using Adapt: adapt, Adapt
9
+ using JLArrays
9
10
using Test
10
11
11
12
using Documenter: doctest
@@ -1100,17 +1101,39 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs)
1100
1101
@test t. b. d isa Array
1101
1102
end
1102
1103
1103
- struct MyArray{T,N} <: AbstractArray{T,N}
1104
- A:: Array{T,N}
1104
+ # The following code defines `MyArray1/2/3` with different `BroadcastStyle`s.
1105
+ # 1. `MyArray1` and `MyArray1` have `similar` defined.
1106
+ # We use them to simulate `BroadcastStyle` overloading `Base.copyto!`.
1107
+ # 2. `MyArray3` has no `similar` defined.
1108
+ # We use it to simulate `BroadcastStyle` overloading `Base.copy`.
1109
+ # 3. Their resolved style could be summaryized as (`-` means conflict)
1110
+ # | MyArray1 | MyArray2 | MyArray3 | Array
1111
+ # -------------------------------------------------------------
1112
+ # MyArray1 | MyArray1 | - | MyArray1 | MyArray1
1113
+ # MyArray2 | - | MyArray2 | - | MyArray2
1114
+ # MyArray3 | MyArray1 | - | MyArray3 | MyArray3
1115
+ # Array | MyArray1 | Array | MyArray3 | Array
1116
+
1117
+ for S in (1 , 2 , 3 )
1118
+ MyArray = Symbol (:MyArray , S)
1119
+ @eval begin
1120
+ struct $ MyArray{T,N} <: AbstractArray{T,N}
1121
+ A:: Array{T,N}
1122
+ end
1123
+ $ MyArray {T} (:: UndefInitializer , sz:: Dims ) where T = $ MyArray (Array {T} (undef, sz))
1124
+ Base. IndexStyle (:: Type{<:$MyArray} ) = IndexLinear ()
1125
+ Base. getindex (A:: $MyArray , i:: Int ) = A. A[i]
1126
+ Base. setindex! (A:: $MyArray , val, i:: Int ) = A. A[i] = val
1127
+ Base. size (A:: $MyArray ) = Base. size (A. A)
1128
+ Base. BroadcastStyle (:: Type{<:$MyArray} ) = Broadcast. ArrayStyle {$MyArray} ()
1129
+ end
1105
1130
end
1106
- MyArray {T} (:: UndefInitializer , sz:: Dims ) where T = MyArray (Array {T} (undef, sz))
1107
- Base. IndexStyle (:: Type{<:MyArray} ) = IndexLinear ()
1108
- Base. getindex (A:: MyArray , i:: Int ) = A. A[i]
1109
- Base. setindex! (A:: MyArray , val, i:: Int ) = A. A[i] = val
1110
- Base. size (A:: MyArray ) = Base. size (A. A)
1111
- Base. BroadcastStyle (:: Type{<:MyArray} ) = Broadcast. ArrayStyle {MyArray} ()
1112
- Base. similar (bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}} , :: Type{ElType} ) where ElType =
1113
- MyArray {ElType} (undef, size (bc))
1131
+ Base. similar (bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}} , :: Type{ElType} ) where ElType =
1132
+ MyArray1 {ElType} (undef, size (bc))
1133
+ Base. similar (bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}} , :: Type{ElType} ) where ElType =
1134
+ MyArray2 {ElType} (undef, size (bc))
1135
+ Base. BroadcastStyle (:: Broadcast.ArrayStyle{MyArray1} , :: Broadcast.ArrayStyle{MyArray3} ) = Broadcast. ArrayStyle {MyArray1} ()
1136
+ Base. BroadcastStyle (:: Broadcast.ArrayStyle{MyArray2} , S:: Broadcast.DefaultArrayStyle ) = S
1114
1137
1115
1138
@testset " broadcast" begin
1116
1139
s = StructArray {ComplexF64} ((rand (2 ,2 ), rand (2 ,2 )))
@@ -1128,19 +1151,44 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
1128
1151
# used inside of broadcast but we also test it here explicitly
1129
1152
@test isa (@inferred (Base. dataids (s)), NTuple{N, UInt} where {N})
1130
1153
1131
- s = StructArray {ComplexF64} ((MyArray (rand (2 )), MyArray (rand (2 ))))
1132
- @test_throws MethodError s .+ s
1133
1154
1155
+ @testset " style conflict check" begin
1156
+ using StructArrays: StructArrayStyle
1157
+ # Make sure we can handle style with similar defined
1158
+ # And we can handle most conflicts
1159
+ # `s1` and `s2` have similar defined, but `s3` does not
1160
+ # `s2` conflicts with `s1` and `s3` and is weaker than `DefaultArrayStyle`
1161
+ s1 = StructArray {ComplexF64} ((MyArray1 (rand (2 )), MyArray1 (rand (2 ))))
1162
+ s2 = StructArray {ComplexF64} ((MyArray2 (rand (2 )), MyArray2 (rand (2 ))))
1163
+ s3 = StructArray {ComplexF64} ((MyArray3 (rand (2 )), MyArray3 (rand (2 ))))
1164
+ s4 = StructArray {ComplexF64} ((rand (2 ), rand (2 )))
1165
+ test_set = Any[s1, s2, s3, s4]
1166
+ tested_style = Any[]
1167
+ dotaddadd ((a, b, c),) = @. a + b + c
1168
+ for as in Iterators. product (test_set, test_set, test_set)
1169
+ ares = map (a-> a. re, as)
1170
+ aims = map (a-> a. im, as)
1171
+ style = Broadcast. combine_styles (ares... )
1172
+ @test Broadcast. combine_styles (as... ) === StructArrayStyle {typeof(style),1} ()
1173
+ if ! (style in tested_style)
1174
+ push! (tested_style, style)
1175
+ if style isa Broadcast. ArrayStyle{MyArray3}
1176
+ @test_throws MethodError dotaddadd (as)
1177
+ else
1178
+ d = StructArray {ComplexF64} ((dotaddadd (ares), dotaddadd (aims)))
1179
+ @test @inferred (dotaddadd (as)):: typeof (d) == d
1180
+ end
1181
+ end
1182
+ end
1183
+ @test length (tested_style) == 5
1184
+ end
1134
1185
# test for dimensionality track
1186
+ s = StructArray {ComplexF64} ((MyArray1 (rand (2 )), MyArray1 (rand (2 ))))
1135
1187
@test Base. broadcasted (+ , s, s) isa Broadcast. Broadcasted{<: Broadcast.AbstractArrayStyle{1} }
1136
1188
@test Base. broadcasted (+ , s, 1 : 2 ) isa Broadcast. Broadcasted{<: Broadcast.AbstractArrayStyle{1} }
1137
1189
@test Base. broadcasted (+ , s, reshape (1 : 2 ,1 ,2 )) isa Broadcast. Broadcasted{<: Broadcast.AbstractArrayStyle{2} }
1138
1190
@test Base. broadcasted (+ , reshape (1 : 2 ,1 ,1 ,2 ), s) isa Broadcast. Broadcasted{<: Broadcast.AbstractArrayStyle{3} }
1139
-
1140
- a = StructArray ([1 ;2 + im])
1141
- b = StructArray ([1 ;;2 + im])
1142
- @test a .+ b == a .+ collect (b) == collect (a) .+ b == collect (a) .+ collect (b)
1143
- @test a .+ Any[1 ] isa StructArray
1191
+ @test Base. broadcasted (+ , s, MyArray1 (rand (2 ))) isa Broadcast. Broadcasted{<: Broadcast.AbstractArrayStyle{Any} }
1144
1192
1145
1193
# issue #185
1146
1194
A = StructArray (randn (ComplexF64, 3 , 3 ))
@@ -1155,6 +1203,61 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
1155
1203
1156
1204
@test identity .(StructArray (x= StructArray (a= 1 : 3 ))):: StructArray == [(x= (a= 1 ,),), (x= (a= 2 ,),), (x= (a= 3 ,),)]
1157
1205
@test (x -> x. x. a). (StructArray (x= StructArray (a= 1 : 3 ))) == [1 , 2 , 3 ]
1206
+ @test identity .(StructArray (x= StructArray (x= StructArray (a= 1 : 3 )))):: StructArray == [(x= (x= (a= 1 ,),),), (x= (x= (a= 2 ,),),), (x= (x= (a= 3 ,),),)]
1207
+ @test (x -> x. x. x. a). (StructArray (x= StructArray (x= StructArray (a= 1 : 3 )))) == [1 , 2 , 3 ]
1208
+
1209
+ @testset " ambiguity check" begin
1210
+ test_set = Any[StructArray ([1 ;2 + im]),
1211
+ 1 : 2 ,
1212
+ (1 ,2 ),
1213
+ StructArray (@SArray [1 ;1 + 2im ]),
1214
+ (@SArray [1 2 ]),
1215
+ 1 ]
1216
+ tested_style = StructArrayStyle[]
1217
+ dotaddsub ((a, b, c),) = @. a + b - c
1218
+ for as in Iterators. product (test_set, test_set, test_set)
1219
+ if any (a -> a isa StructArray, as)
1220
+ style = Broadcast. combine_styles (as... )
1221
+ if ! (style in tested_style)
1222
+ push! (tested_style, style)
1223
+ @test @inferred (dotaddsub (as)):: StructArray == dotaddsub (map (collect, as))
1224
+ end
1225
+ end
1226
+ end
1227
+ @test length (tested_style) == 4
1228
+ end
1229
+
1230
+ @testset " allocation test" begin
1231
+ a = StructArray {ComplexF64} (undef, 1 )
1232
+ allocated (a) = @allocated a .+ 1
1233
+ @test allocated (a) == 2 allocated (a. re)
1234
+ end
1235
+
1236
+ @testset " StructStaticArray" begin
1237
+ bclog (s) = log .(s)
1238
+ test_allocated (f, s) = @test (@allocated f (s)) == 0
1239
+ a = @SMatrix [float (i) for i in 1 : 10 , j in 1 : 10 ]
1240
+ b = @SMatrix [0. for i in 1 : 10 , j in 1 : 10 ]
1241
+ s = StructArray {ComplexF64} ((a , b))
1242
+ @test (@inferred bclog (s)) isa typeof (s)
1243
+ test_allocated (bclog, s)
1244
+ @test abs .(s) .+ ((1 ,) .+ (1 ,2 ,3 ,4 ,5 ,6 ,7 ,8 ,9 ,10 )) isa SMatrix
1245
+ bc = Base. broadcasted (+ , s, s);
1246
+ bc = Base. broadcasted (+ , bc, bc, s);
1247
+ @test @inferred (Broadcast. axes (bc)) === axes (s)
1248
+ end
1249
+
1250
+ @testset " StructJLArray" begin
1251
+ bcabs (a) = abs .(a)
1252
+ bcmul2 (a) = 2 .* a
1253
+ a = StructArray (randn (ComplexF32, 10 , 10 ))
1254
+ sa = jl (a)
1255
+ backend = StructArrays. GPUArraysCore. backend
1256
+ @test @inferred (backend (sa)) === backend (sa. re) === backend (sa. im)
1257
+ @test collect (@inferred (bcabs (sa))) == bcabs (a)
1258
+ @test @inferred (bcmul2 (sa)) isa StructArray
1259
+ @test (sa .+ = 1 ) isa StructArray
1260
+ end
1158
1261
end
1159
1262
1160
1263
@testset " map" begin
0 commit comments