Skip to content

Commit c81c308

Browse files
oschulzN5N3
authored andcommitted
Use weakdeps on Julia v1.9
Update Project.toml
1 parent 99f0556 commit c81c308

File tree

4 files changed

+44
-13
lines changed

4 files changed

+44
-13
lines changed

Project.toml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,15 @@ version = "0.6.16"
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
77
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
88
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
9+
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
10+
11+
[weakdeps]
912
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1013
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
11-
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
14+
15+
[extensions]
16+
StructArraysGPUArraysCoreExt = "GPUArraysCore"
17+
StructArraysStaticArraysCoreExt = "StaticArraysCore"
1218

1319
[compat]
1420
Adapt = "1, 2, 3"
@@ -22,14 +28,16 @@ julia = "1.6"
2228

2329
[extras]
2430
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
31+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
2532
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
2633
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
2734
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
2835
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2936
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
37+
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
3038
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3139
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
3240
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"
3341

3442
[targets]
35-
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays"]
43+
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays", "GPUArraysCore", "StaticArraysCore"]

ext/StructArraysGPUArraysCoreExt.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
module StructArraysGPUArraysCoreExt
2+
3+
using StructArrays
4+
using StructArrays: map_params, array_types
5+
6+
using Base: tail
7+
8+
import GPUArraysCore
9+
10+
# for GPU broadcast
11+
import GPUArraysCore
12+
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
13+
backends = map_params(GPUArraysCore.backend, array_types(T))
14+
backend, others = backends[1], tail(backends)
15+
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
16+
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
17+
return backend
18+
end
19+
StructArrays.always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true
20+
21+
end # module

src/staticarrays_support.jl renamed to ext/StructArraysStaticArraysCoreExt.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
module StructArraysStaticArraysCoreExt
2+
3+
using StructArrays
4+
using StructArrays: StructArrayStyle, createinstance, replace_structarray, isnonemptystructtype
5+
6+
using Base.Broadcast: Broadcasted
7+
18
using StaticArraysCore: StaticArray, FieldArray, tuple_prod
29

310
"""
@@ -40,7 +47,7 @@ Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) = axes(rep
4047

4148
# StaticArrayStyle has no similar defined.
4249
# Overload `Base.copy` instead.
43-
@inline function try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M}
50+
@inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M}
4451
sa = copy(bc)
4552
ET = eltype(sa)
4653
isnonemptystructtype(ET) || return sa
@@ -66,3 +73,5 @@ end
6673
return map(Base.Fix2(getfield, i), x)
6774
end
6875
end
76+
77+
end # module

src/StructArrays.jl

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ include("collect.jl")
1414
include("sort.jl")
1515
include("lazy.jl")
1616
include("tables.jl")
17-
include("staticarrays_support.jl")
1817

1918
# Implement refarray and refvalue to deal with pooled arrays and weakrefstrings effectively
2019
import DataAPI: refarray, refvalue
@@ -30,15 +29,9 @@ end
3029
import Adapt
3130
Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s)
3231

33-
# for GPU broadcast
34-
import GPUArraysCore
35-
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
36-
backends = map_params(GPUArraysCore.backend, array_types(T))
37-
backend, others = backends[1], tail(backends)
38-
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
39-
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
40-
return backend
32+
@static if !isdefined(Base, :get_extension)
33+
include("../ext/StructArraysGPUArraysCoreExt.jl")
34+
include("../ext/StructArraysStaticArraysCoreExt.jl")
4135
end
42-
always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true
4336

4437
end # module

0 commit comments

Comments
 (0)