Skip to content

Commit 8b665db

Browse files
pieverjishnub
andauthored
improve conversion to sparse arrays (#289)
* sparse conversion * map sparse over components * move tests up * Move to package extension and extend issparse * formatting * use createinstance --------- Co-authored-by: Jishnu Bhattacharya <[email protected]>
1 parent 15b044b commit 8b665db

File tree

4 files changed

+41
-4
lines changed

4 files changed

+41
-4
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,20 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
77
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
88
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
99
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
10+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1011
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1112
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1213

1314
[weakdeps]
1415
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
1516
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
17+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1618
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1719

1820
[extensions]
1921
StructArraysAdaptExt = "Adapt"
2022
StructArraysGPUArraysCoreExt = "GPUArraysCore"
23+
StructArraysSparseArraysExt = "SparseArrays"
2124
StructArraysStaticArraysExt = "StaticArrays"
2225

2326
[compat]
@@ -34,6 +37,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3437
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
3538
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
3639
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
40+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
3741
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
3842
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
3943
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -43,4 +47,4 @@ TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
4347
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"
4448

4549
[targets]
46-
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays", "GPUArraysCore", "Adapt"]
50+
test = ["Adapt", "Documenter", "GPUArraysCore", "JLArrays", "LinearAlgebra", "OffsetArrays", "PooledArrays", "SparseArrays", "StaticArrays", "Test", "TypedTables", "WeakRefStrings"]

ext/StructArraysSparseArraysExt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module StructArraysSparseArraysExt
2+
3+
using StructArrays: StructArray, components, createinstance
4+
import SparseArrays: sparse, issparse
5+
6+
function sparse(S::StructArray{T}) where {T}
7+
sparse_components = map(sparse, components(S))
8+
return createinstance.(T, sparse_components...)
9+
end
10+
11+
issparse(S::StructArray) = all(issparse, components(S))
12+
13+
end

src/StructArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ end
2828
@static if !isdefined(Base, :get_extension)
2929
include("../ext/StructArraysAdaptExt.jl")
3030
include("../ext/StructArraysGPUArraysCoreExt.jl")
31+
include("../ext/StructArraysSparseArraysExt.jl")
3132
include("../ext/StructArraysStaticArraysExt.jl")
3233
end
3334

test/runtests.jl

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using DataAPI: refarray, refvalue
88
using Adapt: adapt, Adapt
99
using JLArrays
1010
using GPUArraysCore: backend
11+
using LinearAlgebra
1112
using Test
1213
using SparseArrays
1314

@@ -613,9 +614,10 @@ end
613614
A = spzeros(3)
614615
B = spzeros(3)
615616
S = StructArray{Complex{eltype(A)}}((A,B))
616-
fill!(S, 0)
617-
@test all(iszero, A)
618-
@test all(iszero, B)
617+
fill!(S, 2+3im)
618+
@test all(==(2), A)
619+
@test all(==(3), B)
620+
@test issparse(S)
619621
end
620622
end
621623

@@ -1144,6 +1146,23 @@ end
11441146
end
11451147
end
11461148

1149+
@testset "sparse" begin
1150+
@testset "Vector" begin
1151+
v = [1,0,2]
1152+
sv = StructArray{Complex{Int}}((v, v))
1153+
spv = @inferred sparse(sv)
1154+
@test spv isa SparseVector{eltype(sv)}
1155+
@test spv == sv
1156+
end
1157+
@testset "Matrix" begin
1158+
d = Diagonal(Float64[1:4;])
1159+
sa = StructArray{ComplexF64}((d, d))
1160+
sp = @inferred sparse(sa)
1161+
@test sp isa SparseMatrixCSC{eltype(sa)}
1162+
@test sp == sa
1163+
end
1164+
end
1165+
11471166
struct ArrayConverter end
11481167

11491168
Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs)

0 commit comments

Comments
 (0)