diff --git a/ext/StructArraysStaticArraysExt.jl b/ext/StructArraysStaticArraysExt.jl index 0c84f5b7..14df4325 100644 --- a/ext/StructArraysStaticArraysExt.jl +++ b/ext/StructArraysStaticArraysExt.jl @@ -1,7 +1,7 @@ module StructArraysStaticArraysExt using StructArrays -using StaticArrays: StaticArray, FieldArray, tuple_prod +using StaticArrays: StaticArray, FieldArray, tuple_prod, SVector, MVector """ StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T} @@ -22,7 +22,16 @@ which subtypes `FieldArray`. end end StructArrays.createinstance(::Type{T}, args...) where {T<:StaticArray} = T(args) -StructArrays.component(s::StaticArray, i) = getindex(s, i) +StructArrays.component(s::StaticArray, i::Integer) = getindex(s, i) + +function StructArrays.component(s::StructArray{<:Union{SVector,MVector}}, key::Symbol) + i = key == :x ? 1 : + key == :y ? 2 : + key == :z ? 3 : + key == :w ? 4 : + throw(ArgumentError("invalid key $key")) + StructArrays.component(s, i) +end # invoke general fallbacks for a `FieldArray` type. @inline function StructArrays.staticschema(T::Type{<:FieldArray}) diff --git a/test/runtests.jl b/test/runtests.jl index ed6f573a..30b0ccb6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1405,6 +1405,12 @@ end @test StructArrays.components(x) == ([1.0,2.0], [2.0,3.0]) @test x .+ y == StructArray([StaticVectorType{2}(Float64[2*i+1;2*i+3]) for i = 1:2]) end + for StaticVectorType = [SVector, MVector] + x = @inferred StructArray([StaticVectorType{2}(Float64[i;i+1]) for i = 1:2]) + # numbered and named property access: + @test x.:1 == [1.0,2.0] + @test x.y == [2.0,3.0] + end # test broadcast + components for general arrays for StaticArrayType = [SArray, MArray, SizedArray] x = @inferred StructArray([StaticArrayType{Tuple{1,2}}(ones(1,2) .+ i) for i = 0:1])