Skip to content

Commit 5962572

Browse files
authored
switch to GroupPerm for efficient sortperm (#59)
1 parent bb17d3d commit 5962572

File tree

3 files changed

+53
-86
lines changed

3 files changed

+53
-86
lines changed

src/StructArrays.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ function __init__()
1717
Requires.@require Tables="bd369af6-aec1-5ad0-b16a-f7cc5008161c" include("tables.jl")
1818
Requires.@require WeakRefStrings="ea10d353-3f73-51f8-a26c-33c1cb351aa5" begin
1919
fastpermute!(v::WeakRefStrings.StringArray, p::AbstractVector) = permute!(v, p)
20+
@inline function roweq(a::WeakRefStrings.StringArray{String}, i, j)
21+
weaksa = convert(WeakRefStrings.StringArray{WeakRefStrings.WeakRefString{UInt8}}, a)
22+
@inbounds isequal(weaksa[i], weaksa[j])
23+
end
2024
end
2125
end
2226

src/sort.jl

Lines changed: 34 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -4,99 +4,61 @@ fastpermute!(v::AbstractArray, p::AbstractVector) = copyto!(v, v[p])
44
fastpermute!(v::StructArray, p::AbstractVector) = permute!(v, p)
55
fastpermute!(v::PooledArray, p::AbstractVector) = permute!(v, p)
66

7-
optimize_isequal(v::AbstractArray) = v
8-
optimize_isequal(v::PooledArray) = v.refs
9-
optimize_isequal(v::StructArray{<:Union{Tuple, NamedTuple}}) = StructArray(map(optimize_isequal, fieldarrays(v)))
10-
11-
recover_original(v::AbstractArray, el) = el
12-
recover_original(v::PooledArray, el) = v.pool[el]
13-
recover_original(v::StructArray{T}, el) where {T<:Union{Tuple, NamedTuple}} = T(map(recover_original, fieldarrays(v), el))
14-
15-
pool(v::AbstractArray, condition = !isbitstypeeltype) = condition(v) ? convert(PooledArray, v) : v
16-
pool(v::StructArray, condition = !isbitstypeeltype) = replace_storage(t -> pool(t, condition), v)
17-
187
function Base.permute!(c::StructArray, p::AbstractVector)
198
foreachfield(v -> fastpermute!(v, p), c)
209
return c
2110
end
2211

23-
struct TiedIndices{T<:AbstractVector, V<:AbstractVector{<:Integer}, U<:AbstractUnitRange}
24-
vec::T
25-
perm::V
12+
# Return `v` unchanged when `condition(v)` is false; otherwise convert it to a
# `PooledArray`. The default condition is `!isbitstypeeltype`.
function pool(v::AbstractArray, condition = !isbitstypeeltype)
    condition(v) || return v
    return convert(PooledArray, v)
end
13+
# Apply `pool` (with the same `condition`) to the storage of `v` via
# `replace_storage`.
function pool(v::StructArray, condition = !isbitstypeeltype)
    return replace_storage(v) do t
        pool(t, condition)
    end
end
14+
15+
struct GroupPerm{V<:AbstractVector, P<:AbstractVector{<:Integer}, U<:AbstractUnitRange}
16+
vec::V
17+
perm::P
2618
within::U
2719
end
2820

29-
TiedIndices(vec::AbstractVector, perm=sortperm(vec)) =
30-
TiedIndices(vec, perm, axes(vec, 1))
31-
32-
Base.IteratorSize(::Type{<:TiedIndices}) = Base.SizeUnknown()
21+
# Convenience constructor: uses the full first axis of `vec` as the `within`
# range, and `sortperm(vec)` as the permutation when none is supplied.
function GroupPerm(vec, perm=sortperm(vec))
    return GroupPerm(vec, perm, axes(vec, 1))
end
3322

34-
Base.eltype(::Type{<:TiedIndices{T}}) where {T} =
35-
Pair{eltype(T), UnitRange{Int}}
23+
# The permutation stored in the `GroupPerm` is returned as-is (no copy).
function Base.sortperm(g::GroupPerm)
    return g.perm
end
3624

37-
Base.sortperm(t::TiedIndices) = t.perm
38-
39-
function Base.iterate(n::TiedIndices, i = first(n.within))
40-
vec, perm = n.vec, n.perm
41-
l = last(n.within)
25+
function Base.iterate(g::GroupPerm, i = first(g.within))
26+
vec, perm = g.vec, g.perm
27+
l = last(g.within)
4228
i > l && return nothing
43-
@inbounds row = vec[perm[i]]
29+
@inbounds pi = perm[i]
4430
i1 = i+1
45-
@inbounds while i1 <= l && isequal(row, vec[perm[i1]])
31+
@inbounds while i1 <= l && roweq(vec, pi, perm[i1])
4632
i1 += 1
4733
end
48-
return (row => i:(i1-1), i1)
34+
return (i:(i1-1), i1)
4935
end
5036

51-
"""
52-
`tiedindices(v, perm=sortperm(v))`
53-
54-
Given an abstract vector `v` and a permutation vector `perm`, return an iterator
55-
of pairs `val => range` where `range` is a maximal interval such that `v[perm[range]]`
56-
is constant: `val` is the unique value of `v[perm[range]]`.
57-
"""
58-
tiedindices(v, perm=sortperm(v)) = TiedIndices(v, perm)
59-
60-
"""
61-
`maptiedindices(f, v, perm)`
62-
63-
Given a function `f`, compute the iterator `tiedindices(v, perm)` and return
64-
an iterable object which yields `f(val, idxs)` where `val => idxs` are the pairs
65-
iterated by `tiedindices(v, perm)`.
66-
67-
## Examples
68-
69-
`maptiedindices` is a low level building block that can be used to define grouping
70-
operators. For example:
71-
72-
```jldoctest
73-
julia> function mygroupby(f, keys, data)
74-
perm = sortperm(keys)
75-
StructArrays.maptiedindices(keys, perm) do key, idxs
76-
key => f(data[perm[idxs]])
77-
end
78-
end
79-
mygroupby (generic function with 1 method)
80-
81-
julia> StructArray(mygroupby(sum, [1, 2, 1, 3], [1, 4, 10, 11]))
82-
3-element StructArray{Pair{Int64,Int64},1,NamedTuple{(:first, :second),Tuple{Array{Int64,1},Array{Int64,1}}}}:
83-
1 => 11
84-
2 => 4
85-
3 => 11
86-
```
87-
"""
88-
function maptiedindices(f, v, perm)
89-
fast_v = optimize_isequal(v)
90-
itr = TiedIndices(fast_v, perm)
91-
(f(recover_original(v, val), idxs) for (val, idxs) in itr)
37+
# Iterator trait: a `GroupPerm` reports `SizeUnknown()` for its length.
function Base.IteratorSize(::Type{<:GroupPerm})
    return Base.SizeUnknown()
end
38+
39+
# Declares the element type of a `GroupPerm` iterator as `UnitRange{Int}`.
function Base.eltype(::Type{<:GroupPerm})
    return UnitRange{Int}
end
40+
41+
# Generic row equality: `isequal` on the elements of `x` at positions `i` and
# `j`. Indexing is marked `@inbounds`, so callers must pass valid indices.
@inline function roweq(x::AbstractVector, i, j)
    @inbounds same = isequal(x[i], x[j])
    return same
end
42+
# Row equality for a `PooledArray`: compares the entries of `a.refs` at `i` and
# `j` with `==` instead of looking up the pooled values. Indexing is marked
# `@inbounds`, so callers must pass valid indices.
@inline function roweq(a::PooledArray, i, j)
    @inbounds same = a.refs[i] == a.refs[j]
    return same
end
43+
@generated function roweq(c::StructVector{D,C}, i, j) where {D,C}
44+
N = fieldcount(C)
45+
ex = :(roweq(getfield(fieldarrays(c),1), i, j))
46+
for n in 2:N
47+
ex = :(($ex) && (roweq(getfield(fieldarrays(c),$n), i, j)))
48+
end
49+
ex
9250
end
9351

9452
function uniquesorted(keys, perm=sortperm(keys))
95-
maptiedindices((key, _) -> key, keys, perm)
53+
(keys[perm[idxs[1]]] for idxs in GroupPerm(keys, perm))
9654
end
9755

9856
function finduniquesorted(keys, perm=sortperm(keys))
99-
maptiedindices((key, idxs) -> (key => perm[idxs]), keys, perm)
57+
func = function (idxs)
58+
p_idxs = perm[idxs]
59+
return keys[p_idxs[1]] => p_idxs
60+
end
61+
(func(idxs) for idxs in GroupPerm(keys, perm))
10062
end
10163

10264
function Base.sortperm(c::StructVector{T}) where {T<:Union{Tuple, NamedTuple}}
@@ -126,7 +88,7 @@ function refine_perm!(p, cols, c, x, y′, lo, hi)
12688
order = Perm(Forward, y′)
12789
y = something(forward_vec(order), y′)
12890
nc = length(cols)
129-
for (_, idxs) in TiedIndices(optimize_isequal(x), p, lo:hi)
91+
for idxs in GroupPerm(x, p, lo:hi)
13092
i, i1 = extrema(idxs)
13193
if i1 > i
13294
sort_sub_by!(p, i, i1, y, order, temp)

test/runtests.jl

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,19 @@ end
3030
@test v_pooled == StructArrays.pool(v)
3131
end
3232

33-
@testset "optimize_isequal" begin
33+
@testset "roweq" begin
3434
a = ["a", "b", "a", "a"]
3535
b = PooledArrays.PooledArray(["x", "y", "z", "x"])
3636
s = StructArray((a, b))
37-
t = StructArrays.optimize_isequal(s)
38-
@test t[1] != t[2]
39-
@test t[1] != t[3]
40-
@test t[1] == t[4]
41-
@test t[1][2] isa Integer
42-
@test StructArrays.recover_original(s, t[1]) == s[1]
43-
@test StructArrays.recover_original(s, t[2]) == s[2]
44-
@test StructArrays.recover_original(s, t[3]) == s[3]
45-
@test StructArrays.recover_original(s, t[4]) == s[4]
37+
@test StructArrays.roweq(s, 1, 1)
38+
@test !StructArrays.roweq(s, 1, 2)
39+
@test !StructArrays.roweq(s, 1, 3)
40+
@test StructArrays.roweq(s, 1, 4)
41+
strs = WeakRefStrings.StringArray(["a", "a", "b"])
42+
@test StructArrays.roweq(strs, 1, 1)
43+
@test StructArrays.roweq(strs, 1, 2)
44+
@test !StructArrays.roweq(strs, 1, 3)
45+
@test !StructArrays.roweq(strs, 2, 3)
4646
end
4747

4848
@testset "namedtuple" begin
@@ -95,11 +95,12 @@ end
9595

9696
@testset "iterators" begin
9797
c = [1, 2, 3, 1, 1]
98-
d = StructArrays.tiedindices(c)
99-
@test eltype(d) == Pair{Int, UnitRange{Int}}
98+
d = StructArrays.GroupPerm(c)
99+
@test eltype(d) == UnitRange{Int}
100+
@test Base.IteratorEltype(d) == Base.HasEltype()
101+
@test sortperm(d) == sortperm(c)
100102
s = collect(d)
101-
@test first.(s) == [1, 2, 3]
102-
@test last.(s) == [1:3, 4:4, 5:5]
103+
@test s == [1:3, 4:4, 5:5]
103104
t = collect(StructArrays.finduniquesorted(c))
104105
@test first.(t) == [1, 2, 3]
105106
@test last.(t) == [[1, 4, 5], [2], [3]]

0 commit comments

Comments
 (0)