Skip to content

Commit 4056c71

Browse files
authored
refactor finding consistent value (#252)
* refactor finding consistent value * add internal docs * add doc compat * remove outdated docstring * fix inferrability on older julia
1 parent 0933432 commit 4056c71

File tree

5 files changed

+36
-17
lines changed

5 files changed

+36
-17
lines changed

docs/Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
[deps]
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
33
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
4+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
45

56
[compat]
7+
Documenter = "0.27"
68
PooledArrays = "1"

docs/src/reference.md

+2
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,6 @@ StructArrays.map_params
5757
StructArrays.buildfromschema
5858
StructArrays.bypass_constructor
5959
StructArrays.iscompatible
60+
StructArrays.maybe_convert_elt
61+
StructArrays.findconsistentvalue
6062
```

src/structarray.jl

+6-11
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,10 @@ struct StructArray{T, N, C<:Tup, I} <: AbstractArray{T, N}
1414
components::C
1515

1616
function StructArray{T, N, C}(c) where {T, N, C<:Tup}
17-
isempty(c) && error("only eltypes with fields are supported")
18-
ax = axes(first(c))
19-
length(ax) == N || error("wrong number of dimensions")
20-
map(tail(c)) do ci
21-
axes(ci) == ax || error("all field arrays must have same shape")
22-
end
17+
isempty(c) && throw(ArgumentError("only eltypes with fields are supported"))
18+
ax = findconsistentvalue(axes, c)
19+
(ax === nothing) && throw(ArgumentError("all component arrays must have the same shape"))
20+
length(ax) == N || throw(ArgumentError("wrong number of dimensions"))
2321
new{T, N, C, index_type(c)}(c)
2422
end
2523
end
@@ -119,9 +117,6 @@ Construct a `StructArray` from slices of `A` along `dims`.
119117
The `unwrap` keyword argument is a function that determines whether to
120118
recursively convert fields of type `FT` to `StructArray`s.
121119
122-
!!! compat "Julia 1.1"
123-
This function requires at least Julia 1.1.
124-
125120
```julia-repl
126121
julia> X = [1.0 2.0; 3.0 4.0]
127122
2×2 Array{Float64,2}:
@@ -369,8 +364,8 @@ end
369364
end
370365

371366
function Base.parentindices(s::StructArray)
372-
res = parentindices(component(s, 1))
373-
all(c -> parentindices(c) == res, components(s)) || throw(ArgumentError("inconsistent parentindices of components"))
367+
res = findconsistentvalue(parentindices, components(s))
368+
(res === nothing) && throw(ArgumentError("inconsistent parentindices of components"))
374369
return res
375370
end
376371

src/utils.jl

+12
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,15 @@ By default, this calls `convert(T, x)`; however, you can specialize it for other
196196
maybe_convert_elt(::Type{T}, vals) where T = convert(T, vals)
197197
maybe_convert_elt(::Type{T}, vals::Tuple) where T = T <: Tuple ? convert(T, vals) : vals # assignment of fields by position
198198
maybe_convert_elt(::Type{T}, vals::NamedTuple) where T = T<:NamedTuple ? convert(T, vals) : vals # assignment of fields by name
199+
200+
"""
201+
findconsistentvalue(f, componenents::Union{Tuple, NamedTuple})
202+
203+
Compute the unique value that `f` takes on each `component ∈ componenents`.
204+
If not all values are equal, return `nothing`. Otherwise, return the unique value.
205+
"""
206+
function findconsistentvalue(f::F, (col, cols...)::Tup) where F
207+
val = f(col)
208+
isconsistent = mapfoldl(isequal(val)f, &, cols; init=true)
209+
return ifelse(isconsistent, val, nothing)
210+
end

test/runtests.jl

+14-6
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,14 @@ end
105105
@test StructArrays.strip_params(Tuple{Int}) == Tuple
106106
@test StructArrays.astuple(NamedTuple{(:a,), Tuple{Float64}}) == Tuple{Float64}
107107
@test StructArrays.strip_params(NamedTuple{(:a,), Tuple{Float64}}) == NamedTuple{(:a,)}
108+
109+
cols = (a=rand(2), b=rand(2), c=rand(2))
110+
@test StructArrays.findconsistentvalue(length, cols) == 2
111+
@test StructArrays.findconsistentvalue(length, Tuple(cols)) == 2
112+
113+
cols = (a=rand(2), b=rand(2), c=rand(3))
114+
@test isnothing(StructArrays.findconsistentvalue(length, cols))
115+
@test isnothing(StructArrays.findconsistentvalue(length, Tuple(cols)))
108116
end
109117

110118
@testset "indexstyle" begin
@@ -439,8 +447,8 @@ end
439447
@test isequal(t.a, [1, missing])
440448
@test eltype(t) <: NamedTuple{(:a,)}
441449

442-
@test_throws ErrorException StructArray([nothing])
443-
@test_throws ErrorException StructArray([1, 2, 3])
450+
@test_throws ArgumentError StructArray([nothing])
451+
@test_throws ArgumentError StructArray([1, 2, 3])
444452
end
445453

446454
@testset "tuple case" begin
@@ -460,10 +468,10 @@ end
460468
@test getproperty(t, 1) == [2]
461469
@test getproperty(t, 2) == [3.0]
462470

463-
@test_throws ErrorException StructArray(([1, 2], [3]))
471+
@test_throws ArgumentError StructArray(([1, 2], [3]))
464472

465-
@test_throws ErrorException StructArray{Tuple{}}(())
466-
@test_throws ErrorException StructArray{Tuple{}, 1, Tuple{}}(())
473+
@test_throws ArgumentError StructArray{Tuple{}}(())
474+
@test_throws ArgumentError StructArray{Tuple{}, 1, Tuple{}}(())
467475
end
468476

469477
@testset "constructor from slices" begin
@@ -503,7 +511,7 @@ end
503511
@test t1 == StructArray((a=[1.2], b=["test"]))
504512
@test t2 == StructArray{Pair{Float64, String}}(([1.2], ["test"]))
505513

506-
@test_throws ErrorException StructArray(a=[1, 2], b=[3])
514+
@test_throws ArgumentError StructArray(a=[1, 2], b=[3])
507515
end
508516

509517
@testset "complex" begin

0 commit comments

Comments
 (0)