Skip to content

Commit 525d166

Browse files
committed
support named properties for staticarrays
1 parent d9791eb commit 525d166

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

ext/StructArraysStaticArraysExt.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module StructArraysStaticArraysExt
22

33
using StructArrays
4-
using StaticArrays: StaticArray, FieldArray, tuple_prod
4+
using StaticArrays: StaticArray, FieldArray, tuple_prod, SVector, MVector
55

66
"""
77
StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
@@ -22,7 +22,16 @@ which subtypes `FieldArray`.
2222
end
2323
end
2424
StructArrays.createinstance(::Type{T}, args...) where {T<:StaticArray} = T(args)
25-
StructArrays.component(s::StaticArray, i) = getindex(s, i)
25+
StructArrays.component(s::StaticArray, i::Integer) = getindex(s, i)
26+
27+
function StructArrays.component(s::StructArray{<:Union{SVector,MVector}}, key::Symbol)
28+
i = key == :x ? 1 :
29+
key == :y ? 2 :
30+
key == :z ? 3 :
31+
key == :w ? 4 :
32+
throw(ArgumentError("invalid key $key"))
33+
StructArrays.component(s, i)
34+
end
2635

2736
# invoke general fallbacks for a `FieldArray` type.
2837
@inline function StructArrays.staticschema(T::Type{<:FieldArray})

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,6 +1405,12 @@ end
14051405
@test StructArrays.components(x) == ([1.0,2.0], [2.0,3.0])
14061406
@test x .+ y == StructArray([StaticVectorType{2}(Float64[2*i+1;2*i+3]) for i = 1:2])
14071407
end
1408+
for StaticVectorType = [SVector, MVector]
1409+
x = @inferred StructArray([StaticVectorType{2}(Float64[i;i+1]) for i = 1:2])
1410+
# numbered and named property access:
1411+
@test x.:1 == [1.0,2.0]
1412+
@test x.y == [2.0,3.0]
1413+
end
14081414
# test broadcast + components for general arrays
14091415
for StaticArrayType = [SArray, MArray, SizedArray]
14101416
x = @inferred StructArray([StaticArrayType{Tuple{1,2}}(ones(1,2) .+ i) for i = 0:1])

0 commit comments

Comments
 (0)