Skip to content

Commit 9eeb22b

Browse files
committed
Use weakdeps on Julia v1.9
1 parent f23f3a8 commit 9eeb22b

5 files changed

+53
-13
lines changed

Project.toml

+10
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,16 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
99
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1010
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1111

12+
[weakdeps]
13+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
14+
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
15+
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
16+
17+
[extensions]
18+
StructArraysGPUArraysCoreExt = "GPUArraysCore"
19+
StructArraysStaticArraysCoreExt = "StaticArraysCore"
20+
StructArraysTablesExt = "Tables"
21+
1222
[compat]
1323
Adapt = "1, 2, 3"
1424
DataAPI = "1"

ext/StructArraysGPUArraysCoreExt.jl

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
module StructArraysGPUArraysCoreExt
2+
3+
4+
using StructArrays
5+
import GPUArraysCore
6+
7+
# for GPU broadcast
8+
import GPUArraysCore
9+
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
10+
backends = map_params(GPUArraysCore.backend, array_types(T))
11+
backend, others = backends[1], tail(backends)
12+
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
13+
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
14+
return backend
15+
end
16+
always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true
17+
18+
19+
end # module

src/staticarrays_support.jl renamed to ext/StructArraysStaticArraysCoreExt.jl

+11
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
module StructArraysStaticArraysCoreExt
2+
3+
4+
using StructArrays
5+
using StructArrays: StructArrayStyle
6+
7+
using Base.Broadcast: Broadcasted
8+
19
using StaticArraysCore: StaticArray, FieldArray, tuple_prod
210

311
"""
@@ -66,3 +74,6 @@ end
6674
return map(Base.Fix2(getfield, i), x)
6775
end
6876
end
77+
78+
79+
end # module

src/tables.jl renamed to ext/StructArraysTablesExt.jl

+7
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
module StructArraysTablesExt
2+
3+
4+
using StructArrays
15
import Tables
26

37
Tables.isrowtable(::Type{<:StructArray}) = true
@@ -38,3 +42,6 @@ for (f, g) in zip((:append!, :prepend!), (:push!, :pushfirst!))
3842
end
3943
end
4044
end
45+
46+
47+
end # module

src/StructArrays.jl

+6-13
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,12 @@ include("utils.jl")
1212
include("collect.jl")
1313
include("sort.jl")
1414
include("lazy.jl")
15-
include("tables.jl")
16-
include("staticarrays_support.jl")
15+
16+
@static if !isdefined(Base, :get_extension)
17+
include("../ext/StructArraysGPUArraysCoreExt.jl")
18+
include("../ext/StructArraysTablesExt.jl")
19+
include("../ext/StructArraysStaticArraysCoreExt.jl")
20+
end
1721

1822
# Implement refarray and refvalue to deal with pooled arrays and weakrefstrings effectively
1923
import DataAPI: refarray, refvalue
@@ -29,15 +33,4 @@ end
2933
import Adapt
3034
Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s)
3135

32-
# for GPU broadcast
33-
import GPUArraysCore
34-
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
35-
backends = map_params(GPUArraysCore.backend, array_types(T))
36-
backend, others = backends[1], tail(backends)
37-
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
38-
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
39-
return backend
40-
end
41-
always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true
42-
4336
end # module

0 commit comments

Comments
 (0)