Skip to content

Commit dc2a29b

Browse files
authored
add maptiedindices (#57)
1 parent 3e2ee9d commit dc2a29b

File tree

2 files changed

+56
-11
lines changed

2 files changed

+56
-11
lines changed

src/sort.jl

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ optimize_isequal(v::AbstractArray) = v
88
optimize_isequal(v::PooledArray) = v.refs
99
optimize_isequal(v::StructArray{<:Union{Tuple, NamedTuple}}) = StructArray(map(optimize_isequal, fieldarrays(v)))
1010

11+
recover_original(v::AbstractArray, el) = el
12+
recover_original(v::PooledArray, el) = v.pool[el]
13+
recover_original(v::StructArray{T}, el) where {T<:Union{Tuple, NamedTuple}} = T(map(recover_original, fieldarrays(v), el))
14+
1115
pool(v::AbstractArray, condition = !isbitstypeeltype) = condition(v) ? convert(PooledArray, v) : v
1216
pool(v::StructArray, condition = !isbitstypeeltype) = replace_storage(t -> pool(t, condition), v)
1317

@@ -16,9 +20,9 @@ function Base.permute!(c::StructArray, p::AbstractVector)
1620
return c
1721
end
1822

19-
struct TiedIndices{T<:AbstractVector, I<:Integer, U<:AbstractUnitRange}
23+
struct TiedIndices{T<:AbstractVector, V<:AbstractVector{<:Integer}, U<:AbstractUnitRange}
2024
vec::T
21-
perm::Vector{I}
25+
perm::V
2226
within::U
2327
end
2428

@@ -44,17 +48,55 @@ function Base.iterate(n::TiedIndices, i = first(n.within))
4448
return (row => i:(i1-1), i1)
4549
end
4650

47-
tiedindices(args...) = TiedIndices(args...)
51+
"""
52+
`tiedindices(v, perm=sortperm(v))`
53+
54+
Given an abstract vector `v` and a permutation vector `perm`, return an iterator
55+
of pairs `val => range` where `range` is a maximal interval such as `v[perm[range]]`
56+
is constant: `val` is the unique value of `v[perm[range]]`.
57+
"""
58+
tiedindices(v, perm=sortperm(v)) = TiedIndices(v, perm)
59+
60+
"""
61+
`maptiedindices(f, v, perm)`
62+
63+
Given a function `f`, compute the iterator `tiedindices(v, perm)` and return
64+
in iterable object which yields `f(val, idxs)` where `val => idxs` are the pairs
65+
iterated by `tiedindices(v, perm)`.
66+
67+
## Examples
68+
69+
`maptiedindices` is a low level building block that can be used to define grouping
70+
operators. For example:
71+
72+
```jldoctest
73+
julia> function mygroupby(f, keys, data)
74+
perm = sortperm(keys)
75+
StructArrays.maptiedindices(keys, perm) do key, idxs
76+
key => f(data[perm[idxs]])
77+
end
78+
end
79+
mygroupby (generic function with 1 method)
80+
81+
julia> StructArray(mygroupby(sum, [1, 2, 1, 3], [1, 4, 10, 11]))
82+
3-element StructArray{Pair{Int64,Int64},1,NamedTuple{(:first, :second),Tuple{Array{Int64,1},Array{Int64,1}}}}:
83+
1 => 11
84+
2 => 4
85+
3 => 11
86+
```
87+
"""
88+
function maptiedindices(f, v, perm)
89+
fast_v = optimize_isequal(v)
90+
itr = TiedIndices(fast_v, perm)
91+
(f(recover_original(v, val), idxs) for (val, idxs) in itr)
92+
end
4893

49-
function uniquesorted(args...)
50-
t = tiedindices(args...)
51-
(row for (row, _) in t)
94+
function uniquesorted(keys, perm=sortperm(keys))
95+
maptiedindices((key, _) -> key, keys, perm)
5296
end
5397

54-
function finduniquesorted(args...)
55-
t = tiedindices(args...)
56-
p = sortperm(t)
57-
(row => p[idxs] for (row, idxs) in t)
98+
function finduniquesorted(keys, perm=sortperm(keys))
99+
maptiedindices((key, idxs) -> (key => perm[idxs]), keys, perm)
58100
end
59101

60102
function Base.sortperm(c::StructVector{T}) where {T<:Union{Tuple, NamedTuple}}
@@ -148,4 +190,3 @@ function sort_int_range_sub_by!(x, ioffs, n, by, rangelen, minval, temp)
148190
end
149191
x
150192
end
151-

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ end
3939
@test t[1] != t[3]
4040
@test t[1] == t[4]
4141
@test t[1][2] isa Integer
42+
@test StructArrays.recover_original(s, t[1]) == s[1]
43+
@test StructArrays.recover_original(s, t[2]) == s[2]
44+
@test StructArrays.recover_original(s, t[3]) == s[3]
45+
@test StructArrays.recover_original(s, t[4]) == s[4]
4246
end
4347

4448
@testset "namedtuple" begin

0 commit comments

Comments
 (0)