Skip to content

Commit 572393e

Browse files
committed
Add GPU broadcast support.
1 parent 470a154 commit 572393e

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

Project.toml

+6-3
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,26 @@ version = "0.6.11"
55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
77
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
8+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
89
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
910
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1011

1112
[compat]
1213
Adapt = "1, 2, 3"
1314
DataAPI = "1"
14-
StaticArrays = "1"
15+
GPUArraysCore = "= 0.1.2"
16+
StaticArrays = ">= 1.4.2"
1517
Tables = "1"
16-
julia = "1.3"
18+
julia = "1.6"
1719

1820
[extras]
1921
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
22+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
2023
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
2124
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
2225
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2326
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
2427
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"
2528

2629
[targets]
27-
test = ["Test", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"]
30+
test = ["Test", "JLArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"]

src/StructArrays.jl

+8
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,12 @@ end
2929
import Adapt
3030
Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s)
3131

32+
# for GPU broadcast
33+
import GPUArraysCore: backend
34+
function backend(::Type{T}) where {T<:StructArray}
35+
backs = map(backend, fieldtypes(array_types(T)))
36+
all(Base.Fix2(===, backs[1]), tail(backs)) || error("backend mismatch!")
37+
return backs[1]
38+
end
39+
3240
end # module

test/runtests.jl

+11
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import Tables, PooledArrays, WeakRefStrings
66
using TypedTables: Table
77
using DataAPI: refarray, refvalue
88
using Adapt: adapt, Adapt
9+
using JLArrays
910
using Test
1011

1112
using Documenter: doctest
@@ -1178,6 +1179,16 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
11781179
@test (@inferred bclog(s)) isa typeof(s)
11791180
test_allocated(bclog, s)
11801181
end
1182+
1183+
@testset "StructJLArray" begin
1184+
bcabs(a) = abs.(a)
1185+
bcmul2(a) = 2 .* a
1186+
a = StructArray(randn(ComplexF32, 10, 10))
1187+
sa = jl(a)
1188+
@test collect(@inferred(bcabs(sa))) == bcabs(a)
1189+
@test @inferred(bcmul2(sa)) isa StructArray
1190+
@test (sa .+= 1) isa StructArray
1191+
end
11811192
end
11821193

11831194
@testset "map" begin

0 commit comments

Comments
 (0)