Skip to content

Commit 64c64b2

Browse files
eschnetttkf
andauthored
Provide constructors from generators (#792)
* Add SVector constructors from generators * Add SMatrix constructors from generators * Add SArray constructor from generators * Implement sacollect. Add error messages. This also much simplifies the code. * Simplify sacollect doc string * Make test case backward compatible with Julia <1.5 * Correct test case * Update src/SArray.jl Co-authored-by: Takafumi Arakaki <[email protected]> * Define generator constructors for StaticArray instead of just SArray * Do not export sacollect * Update sacollect tests Co-authored-by: Takafumi Arakaki <[email protected]>
1 parent c75c664 commit 64c64b2

File tree

8 files changed

+139
-1
lines changed

8 files changed

+139
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1111
julia = "1"
1212

1313
[extras]
14+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
1415
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1516
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
16-
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
1717

1818
[targets]
1919
test = ["InteractiveUtils", "Test", "BenchmarkTools"]

src/SArray.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,62 @@ end
5050
end
5151
end
5252

53+
54+
@noinline function generator_too_short_error(inds::CartesianIndices, i::CartesianIndex)
55+
error("Generator produced too few elements: Expected exactly $(shape_string(inds)) elements, but generator stopped at $(shape_string(i))")
56+
end
57+
@noinline function generator_too_long_error(inds::CartesianIndices)
58+
error("Generator produced too many elements: Expected exactly $(shape_string(inds)) elements, but generator yields more")
59+
end
60+
61+
shape_string(inds::CartesianIndices) = join(length.(inds.indices), '×')
62+
shape_string(inds::CartesianIndex) = join(Tuple(inds), '×')
63+
64+
@inline throw_if_nothing(x, inds, i) =
65+
(x === nothing && generator_too_short_error(inds, i); x)
66+
67+
@generated function sacollect(::Type{SA}, gen) where {SA <: StaticArray{S}} where {S <: Tuple}
68+
stmts = [:(Base.@_inline_meta)]
69+
args = []
70+
iter = :(iterate(gen))
71+
inds = CartesianIndices(size_to_tuple(S))
72+
for i in inds
73+
el = Symbol(:el, i)
74+
push!(stmts, :(($el,st) = throw_if_nothing($iter, $inds, $i)))
75+
push!(args, el)
76+
iter = :(iterate(gen,st))
77+
end
78+
push!(stmts, :($iter === nothing || generator_too_long_error($inds)))
79+
push!(stmts, :(SA($(args...))))
80+
Expr(:block, stmts...)
81+
end
82+
"""
83+
sacollect(SA, gen)
84+
85+
Construct a statically-sized vector of type `SA`.from a generator
86+
`gen`. `SA` needs to have a size parameter since the length of `vec`
87+
is unknown to the compiler. `SA` can optionally specify the element
88+
type as well.
89+
90+
Example:
91+
92+
sacollect(SVector{3, Int}, 2i+1 for i in 1:3)
93+
sacollect(SMatrix{2, 3}, i+j for i in 1:2, j in 1:3)
94+
sacollect(SArray{2, 3}, i+j for i in 1:2, j in 1:3)
95+
96+
This creates the same statically-sized vector as if the generator were
97+
collected in an array, but is more efficient since no array is
98+
allocated.
99+
100+
Equivalent:
101+
102+
SVector{3, Int}([2i+1 for i in 1:3])
103+
"""
104+
sacollect
105+
106+
@inline (::Type{SA})(gen::Base.Generator) where {SA <: StaticArray} =
107+
sacollect(SA, gen)
108+
53109
@inline SArray(a::StaticArray) = SArray{size_tuple(Size(a))}(Tuple(a))
54110

55111
####################

src/SMatrix.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ end
5252
end
5353
end
5454

55+
@inline SMatrix{M, N, T}(gen::Base.Generator) where {M, N, T} =
56+
sacollect(SMatrix{M, N, T}, gen)
57+
@inline SMatrix{M, N}(gen::Base.Generator) where {M, N} =
58+
sacollect(SMatrix{M, N}, gen)
59+
5560
@inline convert(::Type{SMatrix{S1,S2}}, a::StaticArray{<:Tuple, T}) where {S1,S2,T} = SMatrix{S1,S2,T}(Tuple(a))
5661
@inline SMatrix(a::StaticMatrix{S1, S2}) where {S1, S2} = SMatrix{S1, S2}(Tuple(a))
5762

src/SVector.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ const SVector{S, T} = SArray{Tuple{S}, T, 1, S}
1919
@inline SVector{S}(x::NTuple{S,T}) where {S, T} = SVector{S,T}(x)
2020
@inline SVector{S}(x::T) where {S, T <: Tuple} = SVector{S,promote_tuple_eltype(T)}(x)
2121

22+
@inline SVector{N, T}(gen::Base.Generator) where {N, T} =
23+
sacollect(SVector{N, T}, gen)
24+
@inline SVector{N}(gen::Base.Generator) where {N} =
25+
sacollect(SVector{N}, gen)
26+
2227
# conversion from AbstractVector / AbstractArray (better inference than default)
2328
#@inline convert{S,T}(::Type{SVector{S}}, a::AbstractArray{T}) = SVector{S,T}((a...))
2429

test/MArray.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,29 @@
3232
v = MArray{Tuple{2}}(1,2)
3333
@test MArray(v) !== v && MArray(v) == v
3434

35+
@test MArray{Tuple{}}(i for i in 1:1).data === (1,)
36+
@test MArray{Tuple{3}}(i for i in 1:3).data === (1,2,3)
37+
@test MArray{Tuple{3}}(float(i) for i in 1:3).data === (1.0,2.0,3.0)
38+
@test MArray{Tuple{2,3}}(i+10j for i in 1:2, j in 1:3).data === (11,12,21,22,31,32)
39+
@test MArray{Tuple{1,2,3}}(i+10j+100k for i in 1:1, j in 1:2, k in 1:3).data === (111,121,211,221,311,321)
40+
@test_throws Exception MArray{Tuple{}}(i for i in 1:0)
41+
@test_throws Exception MArray{Tuple{}}(i for i in 1:2)
42+
@test_throws Exception MArray{Tuple{3}}(i for i in 1:2)
43+
@test_throws Exception MArray{Tuple{3}}(i for i in 1:4)
44+
@test_throws Exception MArray{Tuple{2,3}}(10i+j for i in 1:1, j in 1:3)
45+
@test_throws Exception MArray{Tuple{2,3}}(10i+j for i in 1:3, j in 1:3)
46+
47+
@test StaticArrays.sacollect(MVector{6}, Iterators.product(1:2, 1:3)) ==
48+
MVector{6}(collect(Iterators.product(1:2, 1:3)))
49+
@test StaticArrays.sacollect(MVector{2}, Iterators.zip(1:2, 2:3)) ==
50+
MVector{2}(collect(Iterators.zip(1:2, 2:3)))
51+
@test StaticArrays.sacollect(MVector{3}, Iterators.take(1:10, 3)) ==
52+
MVector{3}(collect(Iterators.take(1:10, 3)))
53+
@test StaticArrays.sacollect(MMatrix{2,3}, Iterators.product(1:2, 1:3)) ==
54+
MMatrix{2,3}(collect(Iterators.product(1:2, 1:3)))
55+
@test StaticArrays.sacollect(MArray{Tuple{2,3,4}}, 1:24) ==
56+
MArray{Tuple{2,3,4}}(collect(1:24))
57+
3558
@test ((@MArray [1])::MArray{Tuple{1}}).data === (1,)
3659
@test ((@MArray [1,2])::MArray{Tuple{2}}).data === (1,2)
3760
@test ((@MArray Float64[1,2,3])::MArray{Tuple{3}}).data === (1.0, 2.0, 3.0)

test/SArray.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,29 @@
2727

2828
@test SArray(SArray{Tuple{2}}(1,2)) === SArray{Tuple{2}}(1,2)
2929

30+
@test SArray{Tuple{}}(i for i in 1:1).data === (1,)
31+
@test SArray{Tuple{3}}(i for i in 1:3).data === (1,2,3)
32+
@test SArray{Tuple{3}}(float(i) for i in 1:3).data === (1.0,2.0,3.0)
33+
@test SArray{Tuple{2,3}}(i+10j for i in 1:2, j in 1:3).data === (11,12,21,22,31,32)
34+
@test SArray{Tuple{1,2,3}}(i+10j+100k for i in 1:1, j in 1:2, k in 1:3).data === (111,121,211,221,311,321)
35+
@test_throws Exception SArray{Tuple{}}(i for i in 1:0)
36+
@test_throws Exception SArray{Tuple{}}(i for i in 1:2)
37+
@test_throws Exception SArray{Tuple{3}}(i for i in 1:2)
38+
@test_throws Exception SArray{Tuple{3}}(i for i in 1:4)
39+
@test_throws Exception SArray{Tuple{2,3}}(10i+j for i in 1:1, j in 1:3)
40+
@test_throws Exception SArray{Tuple{2,3}}(10i+j for i in 1:3, j in 1:3)
41+
42+
@test StaticArrays.sacollect(SVector{6}, Iterators.product(1:2, 1:3)) ==
43+
SVector{6}(collect(Iterators.product(1:2, 1:3)))
44+
@test StaticArrays.sacollect(SVector{2}, Iterators.zip(1:2, 2:3)) ==
45+
SVector{2}(collect(Iterators.zip(1:2, 2:3)))
46+
@test StaticArrays.sacollect(SVector{3}, Iterators.take(1:10, 3)) ==
47+
SVector{3}(collect(Iterators.take(1:10, 3)))
48+
@test StaticArrays.sacollect(SMatrix{2,3}, Iterators.product(1:2, 1:3)) ==
49+
SMatrix{2,3}(collect(Iterators.product(1:2, 1:3)))
50+
@test StaticArrays.sacollect(SArray{Tuple{2,3,4}}, 1:24) ==
51+
SArray{Tuple{2,3,4}}(collect(1:24))
52+
3053
@test ((@SArray [1])::SArray{Tuple{1}}).data === (1,)
3154
@test ((@SArray [1,2])::SArray{Tuple{2}}).data === (1,2)
3255
@test ((@SArray Float64[1,2,3])::SArray{Tuple{3}}).data === (1.0, 2.0, 3.0)

test/SMatrix.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,22 @@
2626
@test SMatrix{2}((1,2,3,4)).data === (1,2,3,4)
2727
@test_throws DimensionMismatch SMatrix{2}((1,2,3,4,5))
2828

29+
@test (SMatrix{2,3}(i+10j for i in 1:2, j in 1:3)::SMatrix{2,3}).data ===
30+
(11,12,21,22,31,32)
31+
@test (SMatrix{2,3}(float(i+10j) for i in 1:2, j in 1:3)::SMatrix{2,3}).data ===
32+
(11.0,12.0,21.0,22.0,31.0,32.0)
33+
@test (SMatrix{0,0,Int}()::SMatrix{0,0}).data === ()
34+
@test (SMatrix{0,3,Int}()::SMatrix{0,3}).data === ()
35+
@test (SMatrix{2,0,Int}()::SMatrix{2,0}).data === ()
36+
@test (SMatrix{2,3,Int}(i+10j for i in 1:2, j in 1:3)::SMatrix{2,3}).data ===
37+
(11,12,21,22,31,32)
38+
@test (SMatrix{2,3,Float64}(i+10j for i in 1:2, j in 1:3)::SMatrix{2,3}).data ===
39+
(11.0,12.0,21.0,22.0,31.0,32.0)
40+
@test_throws Exception SMatrix{2,3}(i+10j for i in 1:1, j in 1:3)
41+
@test_throws Exception SMatrix{2,3}(i+10j for i in 1:3, j in 1:3)
42+
@test_throws Exception SMatrix{2,3,Int}(i+10j for i in 1:1, j in 1:3)
43+
@test_throws Exception SMatrix{2,3,Int}(i+10j for i in 1:3, j in 1:3)
44+
2945
@test ((@SMatrix [1.0])::SMatrix{1,1}).data === (1.0,)
3046
@test ((@SMatrix [1 2])::SMatrix{1,2}).data === (1, 2)
3147
@test ((@SMatrix [1 ; 2])::SMatrix{2,1}).data === (1, 2)

test/SVector.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,16 @@
1616
@test SVector((1,)).data === (1,)
1717
@test SVector((1.0,)).data === (1.0,)
1818

19+
@test SVector{3}(i for i in 1:3).data === (1,2,3)
20+
@test SVector{3}(float(i) for i in 1:3).data === (1.0,2.0,3.0)
21+
@test SVector{0,Int}().data === ()
22+
@test SVector{3,Int}(i for i in 1:3).data === (1,2,3)
23+
@test SVector{3,Float64}(i for i in 1:3).data === (1.0,2.0,3.0)
24+
@test_throws Exception SVector{3}(i for i in 1:2)
25+
@test_throws Exception SVector{3}(i for i in 1:4)
26+
@test_throws Exception SVector{3,Int}(i for i in 1:2)
27+
@test_throws Exception SVector{3,Int}(i for i in 1:4)
28+
1929
@test SVector(1).data === (1,)
2030
@test SVector(1,1.0).data === (1.0,1.0)
2131

0 commit comments

Comments
 (0)