Skip to content

Commit 7b479c4

Browse files
author
Pietro Vertechi
authored
internal refactor (#210)
* internal refactor * cleanup * remove params * extra test
1 parent 275ac0f commit 7b479c4

File tree

6 files changed

+63
-74
lines changed

6 files changed

+63
-74
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StructArrays"
22
uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
3-
version = "0.6.2"
3+
version = "0.6.3"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

docs/src/reference.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ StructArrays.createinstance
5454
```@docs
5555
StructArrays.get_ith
5656
StructArrays.map_params
57-
StructArrays._map_params
5857
StructArrays.buildfromschema
5958
StructArrays.bypass_constructor
6059
StructArrays.iscompatible

src/StructArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module StructArrays
22

3-
using Base: tuple_type_cons, tuple_type_head, tuple_type_tail, tail
3+
using Base: tail
44

55
export StructArray, StructVector, LazyRow, LazyRows
66
export collect_structarray

src/structarray.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,13 @@ function StructArray{T}(c::C) where {T, C<:Tup}
9797
StructArray{T, N, typeof(cols)}(cols)
9898
end
9999

100-
StructArray(c::C) where {C<:NamedTuple} = StructArray{eltypes(C)}(c)
100+
StructArray(c::NamedTuple) = StructArray{eltypes(c)}(c)
101101
StructArray(c::Tuple; names = nothing) = _structarray(c, names)
102102

103103
StructArray{T}(; kwargs...) where {T} = StructArray{T}(values(kwargs))
104104
StructArray(; kwargs...) = StructArray(values(kwargs))
105105

106-
_structarray(args::T, ::Nothing) where {T<:Tuple} = StructArray{eltypes(T)}(args)
106+
_structarray(args::Tuple, ::Nothing) = StructArray{eltypes(args)}(args)
107107
_structarray(args::Tuple, names) = _structarray(args, Tuple(names))
108108
_structarray(args::Tuple, ::Tuple) = _structarray(args, nothing)
109109
_structarray(args::NTuple{N, Any}, names::NTuple{N, Symbol}) where {N} = StructArray(NamedTuple{names}(args))

src/utils.jl

Lines changed: 25 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,39 @@
1-
eltypes(::Type{T}) where {T} = map_params(eltype, T)
1+
argtail(_, args...) = args
22

3-
alwaysfalse(t) = false
4-
5-
"""
6-
StructArrays.map_params(f, T)
7-
8-
Apply `f` to each field type of `Tuple` or `NamedTuple` type `T`, returning a
9-
new `Tuple` or `NamedTuple` type.
3+
split_tuple_type(T) = fieldtype(T, 1), Tuple{argtail(T.parameters...)...}
104

11-
```julia-repl
12-
julia> StructArrays.map_params(T -> Complex{T}, Tuple{Int32,Float64})
13-
Tuple{Complex{Int32},Complex{Float64}}
14-
```
15-
"""
16-
map_params(f, ::Type{NamedTuple{names, types}}) where {names, types} =
17-
NamedTuple{names, map_params(f, types)}
5+
eltypes(nt::NamedTuple{names}) where {names} = NamedTuple{names, eltypes(values(nt))}
6+
eltypes(t::Tuple) = Tuple{map(eltype, t)...}
187

19-
function map_params(f, ::Type{T}) where {T<:Tuple}
20-
if @generated
21-
types = fieldtypes(T)
22-
ex = :(Tuple{})
23-
for t types
24-
push!(ex.args, :(f($t)))
25-
end
26-
ex
27-
else
28-
map_params_fallback(f, T)
29-
end
30-
end
31-
32-
map_params_fallback(f, ::Type{T}) where {T<:Tuple} = Tuple{map(f, fieldtypes(T))...}
8+
alwaysfalse(t) = false
339

3410
"""
35-
StructArrays._map_params(f, T)
11+
StructArrays.map_params(f, T)
3612
3713
Apply `f` to each field type of `Tuple` or `NamedTuple` type `T`, returning a
3814
new `Tuple` or `NamedTuple` object.
3915
4016
```julia-repl
41-
julia> StructArrays._map_params(T -> Complex{T}, Tuple{Int32,Float64})
17+
julia> StructArrays.map_params(T -> Complex{T}, Tuple{Int32,Float64})
4218
(Complex{Int32}, Complex{Float64})
4319
```
4420
"""
45-
_map_params(f::F, ::Type{NamedTuple{names, types}}) where {names, types, F} =
46-
NamedTuple{names}(_map_params(f, types))
21+
map_params(f::F, ::Type{NamedTuple{names, types}}) where {F, names, types} =
22+
NamedTuple{names}(map_params(f, types))
4723

48-
function _map_params(f::F, ::Type{T}) where {T<:Tuple, F}
24+
function map_params(f::F, ::Type{T}) where {F, T<:Tuple}
4925
if @generated
5026
types = fieldtypes(T)
51-
ex = :()
52-
for t types
53-
push!(ex.args, :(f($t)))
54-
end
55-
ex
27+
args = map(t -> :(f($t)), types)
28+
Expr(:tuple, args...)
5629
else
57-
_map_params_fallback(f, T)
30+
map_params_fallback(f, T)
5831
end
5932
end
6033

61-
_map_params_fallback(f, ::Type{T}) where {T<:Tuple} = map(f, fieldtypes(T))
34+
map_params_fallback(f, ::Type{T}) where {T<:Tuple} = map(f, fieldtypes(T))
6235

63-
buildfromschema(initializer::F, ::Type{T}) where {T, F} = buildfromschema(initializer, T, staticschema(T))
36+
buildfromschema(initializer::F, ::Type{T}) where {F, T} = buildfromschema(initializer, T, staticschema(T))
6437

6538
"""
6639
StructArrays.buildfromschema(initializer, T[, S])
@@ -71,8 +44,8 @@ Construct a [`StructArray{T}`](@ref) with a function `initializer`, using a sche
7144
7245
`S` is a `Tuple` or `NamedTuple` type. The default value is [`staticschema(T)`](@ref).
7346
"""
74-
function buildfromschema(initializer::F, ::Type{T}, ::Type{NT}) where {T, NT<:Tup, F}
75-
nt = _map_params(initializer, NT)
47+
function buildfromschema(initializer::F, ::Type{T}, ::Type{NT}) where {F, T, NT<:Tup}
48+
nt = map_params(initializer, NT)
7649
StructArray{T}(nt)
7750
end
7851

@@ -123,16 +96,16 @@ iscompatible(::Type{Tuple{}}, ::Type{T}) where {T<:Tuple} = false
12396
iscompatible(::Type{T}, ::Type{Tuple{}}) where {T<:Tuple} = false
12497
iscompatible(::Type{Tuple{}}, ::Type{Tuple{}}) = true
12598

126-
function iscompatible(::Type{S}, ::Type{T}) where {S<:Tuple, T<:Tuple}
127-
iscompatible(tuple_type_head(S), tuple_type_head(T)) && iscompatible(tuple_type_tail(S), tuple_type_tail(T))
99+
function iscompatible(::Type{T}, ::Type{T′}) where {T<:Tuple, T′<:Tuple}
100+
(f, ls), (f′, ls′) = split_tuple_type(T), split_tuple_type(T′)
101+
iscompatible(f, f′) && iscompatible(ls, ls′)
128102
end
129103

130-
iscompatible(::S, ::T) where {S, T<:AbstractArray} = iscompatible(S, T)
104+
iscompatible(::S, ::V) where {S, V<:AbstractArray} = iscompatible(S, V)
131105

132-
function _promote_typejoin(::Type{S}, ::Type{T}) where {S<:NTuple{N, Any}, T<:NTuple{N, Any}} where N
133-
head = _promote_typejoin(Base.tuple_type_head(S), Base.tuple_type_head(T))
134-
tail = _promote_typejoin(Base.tuple_type_tail(S), Base.tuple_type_tail(T))
135-
return Base.tuple_type_cons(head, tail)
106+
function _promote_typejoin(::Type{T}, ::Type{T′}) where {T<:NTuple{N, Any}, T′<:NTuple{N, Any}} where N
107+
(f, ls), (f′, ls′) = split_tuple_type(T), split_tuple_type(T′)
108+
return Tuple{_promote_typejoin(f, f′), _promote_typejoin(ls, ls′).parameters...}
136109
end
137110

138111
_promote_typejoin(::Type{Tuple{}}, ::Type{Tuple{}}) = Tuple{}
@@ -141,7 +114,7 @@ function _promote_typejoin(::Type{NamedTuple{names, types}}, ::Type{NamedTuple{n
141114
return NamedTuple{names, T}
142115
end
143116

144-
_promote_typejoin(::Type{S}, ::Type{T}) where {S, T} = Base.promote_typejoin(S, T)
117+
_promote_typejoin(::Type{T}, ::Type{T}) where {T, T} = Base.promote_typejoin(T, T)
145118

146119
function _promote_typejoin(::Type{Pair{A, B}}, ::Type{Pair{A′, B′}}) where {A, A′, B, B′}
147120
C = _promote_typejoin(A, A′)

test/runtests.jl

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,7 @@ end
2929

3030
@testset "utils" begin
3131
t = StructArray(rand(ComplexF64, 2, 2))
32-
T = staticschema(eltype(t))
33-
@test StructArrays.eltypes(T) == NamedTuple{(:re, :im), Tuple{Float64, Float64}}
34-
@test StructArrays.map_params(eltype, T) == NamedTuple{(:re, :im), Tuple{Float64, Float64}}
35-
@test StructArrays.map_params(eltype, StructArrays.astuple(T)) == Tuple{Float64, Float64}
32+
@test StructArrays.eltypes((re = 1.0, im = 1.0)) == NamedTuple{(:re, :im), Tuple{Float64, Float64}}
3633
@test !iscompatible(typeof((1, 2)), typeof(([1],)))
3734
@test iscompatible(typeof((1, 2)), typeof(([1], [2])))
3835
@test !iscompatible(typeof((1, 2)), typeof(([1.1], [2])))
@@ -343,6 +340,21 @@ g_infer() = StructArray([(a=(b="1",), c=2)], unwrap = t -> t <: NamedTuple)
343340
tup_infer() = StructArray([(1, 2), (3, 4)])
344341
cols_infer() = StructArray(([1, 2], [1.2, 2.3]))
345342
nt_infer(nt) = StructArray{typeof(nt)}(undef, 4)
343+
eltype_infer() = StructArray((rand(10), rand(Int, 10)))
344+
named_eltype_infer() = StructArray((x=rand(10), y=rand(Int, 10)))
345+
compatible_infer() = Val(iscompatible(Tuple{Int, Int}, Tuple{Vector{Int}, Vector{Real}}))
346+
function promote_infer()
347+
x = (a=1, b=1.2)
348+
y = (a=1.2, b="a")
349+
T = _promote_typejoin(typeof(x), typeof(y))
350+
return convert(T, x)
351+
end
352+
function map_params_infer()
353+
v = StructArray(rand(ComplexF64, 2, 2))
354+
f(T) = similar(v, T)
355+
types = Tuple{Int, Float64, ComplexF32, String}
356+
return StructArrays.map_params(f, types)
357+
end
346358

347359
@testset "inferrability" begin
348360
@inferred f_infer()
@@ -354,6 +366,11 @@ nt_infer(nt) = StructArray{typeof(nt)}(undef, 4)
354366
@test s[2] == (3, 4)
355367
@inferred cols_infer()
356368
@inferred nt_infer((x = 3, y = :a, z = :b))
369+
@inferred eltype_infer()
370+
@inferred named_eltype_infer()
371+
@inferred compatible_infer()
372+
@inferred promote_infer()
373+
@inferred map_params_infer()
357374
end
358375

359376
@testset "propertynames" begin
@@ -882,21 +899,21 @@ end
882899
end
883900
end
884901

885-
# Test fallback (non-@generated) variant of _map_params
886-
@testset "_map_params" begin
902+
# Test fallback (non-@generated) variant of map_params
903+
@testset "map_params" begin
887904
v = StructArray(rand(ComplexF64, 2, 2))
888905
f(T) = similar(v, T)
889906
types = Tuple{Int, Float64, ComplexF32, String}
890-
A = @inferred StructArrays._map_params(f, types)
891-
B = StructArrays._map_params_fallback(f, types)
892-
@test typeof(A) === typeof(B)
893-
end
894-
895-
# Same for map_params
896-
@testset "map_params" begin
897-
types = Tuple{Int, Float64, Int32}
898-
f(T) = Complex{T}
899907
A = @inferred StructArrays.map_params(f, types)
900908
B = StructArrays.map_params_fallback(f, types)
901-
@test A === B
902-
end
909+
@test typeof(A) === typeof(B)
910+
types = Tuple{Int, Float64, ComplexF32}
911+
A = @inferred StructArrays.map_params(zero, types)
912+
B = StructArrays.map_params_fallback(zero, types)
913+
C = map(zero, fieldtypes(types))
914+
@test A === B === C
915+
namedtypes = NamedTuple{(:a, :b, :c), types}
916+
A = @inferred StructArrays.map_params(zero, namedtypes)
917+
C = map(zero, NamedTuple{(:a, :b, :c)}(map(zero, fieldtypes(types))))
918+
@test A === C
919+
end

0 commit comments

Comments
 (0)