Skip to content

Commit 878ce7c

Browse files
author
Pietro Vertechi
committed
add GroupJoinPerm
1 parent e48271b commit 878ce7c

File tree

3 files changed

+72
-0
lines changed

3 files changed

+72
-0
lines changed

src/StructArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ include("structarray.jl")
1111
include("utils.jl")
1212
include("collect.jl")
1313
include("sort.jl")
14+
include("groupjoin.jl")
1415
include("lazy.jl")
1516

1617
function __init__()

src/groupjoin.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
@generated function rowcmp(c::StructVector, i, d::StructVector{C, D}, j) where {C, D}
2+
N = fieldcount(D)
3+
ex = :(cmp(getfield(fieldarrays(c),$N)[i], getfield(fieldarrays(d),$N)[j]))
4+
for n in N-1:-1:1
5+
ex = quote
6+
let k = rowcmp(getfield(fieldarrays(c),$n), i, getfield(fieldarrays(d),$n), j)
7+
(k == 0) ? ($ex) : k
8+
end
9+
end
10+
end
11+
ex
12+
end
13+
14+
@inline function rowcmp(c::AbstractVector, i, d::AbstractVector, j)
15+
cmp(c[i], d[j])
16+
end
17+
18+
struct GroupJoinPerm{LP<:GroupPerm, RP<:GroupPerm}
19+
left::LP
20+
right::RP
21+
end
22+
23+
GroupJoinPerm(lkeys::AbstractVector, rkeys::AbstractVector, lperm=sortperm(lkeys), rperm=sortperm(rkeys)) =
24+
GroupJoinPerm(GroupPerm(lkeys, lperm), GroupPerm(rkeys, rperm))
25+
26+
function _pick(s, a, b)
27+
if a === nothing && b === nothing
28+
return nothing
29+
elseif a === nothing
30+
return (1:0, b[1]), (1, a, b)
31+
elseif b === nothing
32+
return (a[1], 1:0), (-1, a, b)
33+
else
34+
lp = sortperm(s.left)
35+
rp = sortperm(s.right)
36+
cmp = rowcmp(parent(s.left), lp[first(a[1])], parent(s.right), rp[first(b[1])])
37+
if cmp < 0
38+
return (a[1], 1:0), (-1, a, b)
39+
elseif cmp == 0
40+
return (a[1], b[1]), (0, a, b)
41+
else
42+
return (1:0, b[1]), (1, a, b)
43+
end
44+
end
45+
end
46+
47+
function Base.iterate(s::GroupJoinPerm)
48+
l = iterate(s.left)
49+
r = iterate(s.right)
50+
_pick(s, l, r)
51+
end
52+
53+
function Base.iterate(s::GroupJoinPerm, (select, l, r))
54+
(select <= 0) && (l = iterate(s.left, l[2]))
55+
(select >= 0) && (r = iterate(s.right, r[2]))
56+
_pick(s, l, r)
57+
end
58+
59+
Base.IteratorSize(::Type{<:GroupJoinPerm}) = Base.SizeUnknown()
60+
Base.IteratorEltype(::Type{<:GroupJoinPerm}) = Base.HasEltype()
61+
Base.eltype(::Type{<:GroupJoinPerm}) = Tuple{UnitRange{Int}, UnitRange{Int}}

test/runtests.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,16 @@ end
109109
@test u == [1, 2, 3]
110110
end
111111

112+
@testset "groupjoin" begin
113+
a = [1, 2, 1, 1, 0, 9, -100]
114+
b = [-2, 12, 1, 1, 0, 11, 9]
115+
itr = StructArrays.GroupJoinPerm(a, b)
116+
s = StructArray(itr)
117+
as, bs = fieldarrays(s)
118+
@test as == [1:1, 1:0, 2:2, 3:5, 6:6, 7:7, 1:0, 1:0]
119+
@test bs == [1:0, 1:1, 2:2, 3:4, 1:0, 5:5, 6:6, 7:7]
120+
end
121+
112122
@testset "similar" begin
113123
t = StructArray(a = rand(10), b = rand(Bool, 10))
114124
s = similar(t)

0 commit comments

Comments
 (0)