Skip to content

Commit dd29059

Browse files
authored
Compatibility with SplittablesBase (#14)
* compatibility with SplittablesBase * ignore ambiguities on old versions * version bump to 0.8.6 * remove reduntant methods
1 parent aa838b3 commit dd29059

7 files changed

+79
-10
lines changed

Project.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
name = "ParallelUtilities"
22
uuid = "fad6cfc8-4f83-11e9-06cc-151124046ad0"
33
authors = ["Jishnu Bhattacharya"]
4-
version = "0.8.5"
4+
version = "0.8.6"
55

66
[deps]
77
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
88
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
9+
SplittablesBase = "171d559e-b47b-412a-8079-5efa626c420e"
910

1011
[compat]
1112
DataStructures = "0.17, 0.18"
13+
SplittablesBase = "0.1"
1214
julia = "1.2"
1315

1416
[extras]

src/ParallelUtilities.jl

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module ParallelUtilities
22

33
using Distributed
4+
using SplittablesBase
45

56
export pmapreduce
67
export pmapreduce_productsplit

src/mapreduce.jl

+30-7
Original file line numberDiff line numberDiff line change
@@ -37,21 +37,44 @@ getiterators(h::Hold) = getiterators(h.iterators)
3737

3838
Base.length(h::Hold) = length(h.iterators)
3939

40-
check_knownsize(iterators::Tuple) = _check_knownsize(first(iterators)) & check_knownsize(Base.tail(iterators))
41-
check_knownsize(::Tuple{}) = true
42-
function _check_knownsize(iterator)
40+
function check_knownsize(iterator)
4341
itsz = Base.IteratorSize(iterator)
4442
itsz isa Base.HasLength || itsz isa Base.HasShape
4543
end
4644

47-
function zipsplit(iterators::Tuple, np::Integer, p::Integer)
48-
check_knownsize(iterators)
49-
itzip = zip(iterators...)
45+
struct ZipSplit{Z, I}
46+
z :: Z
47+
it :: I
48+
skip :: Int
49+
N :: Int
50+
end
51+
52+
# This constructor differs from zipsplit, as it uses skipped and retained elements
53+
# and not p and np. This type is added to increase compatibility with SplittablesBase
54+
function ZipSplit(itzip, skipped_elements::Integer, elements_on_proc::Integer)
55+
it = Iterators.take(Iterators.drop(itzip, skipped_elements), elements_on_proc)
56+
ZipSplit{typeof(itzip), typeof(it)}(itzip, it, skipped_elements, elements_on_proc)
57+
end
58+
59+
Base.length(zs::ZipSplit) = length(zs.it)
60+
Base.eltype(zs::ZipSplit) = eltype(zs.it)
61+
Base.iterate(z::ZipSplit, i...) = iterate(takedrop(z), i...)
62+
takedrop(zs::ZipSplit) = zs.it
63+
64+
function SplittablesBase.halve(zs::ZipSplit)
65+
nleft = zs.N ÷ 2
66+
ZipSplit(zs.z, zs.skip, nleft), ZipSplit(zs.z, zs.skip + nleft, zs.N - nleft)
67+
end
68+
69+
zipsplit(iterators::Tuple, np::Integer, p::Integer) = zipsplit(zip(iterators...), np, p)
70+
71+
function zipsplit(itzip::Iterators.Zip, np::Integer, p::Integer)
72+
check_knownsize(itzip)
5073
d,r = divrem(length(itzip), np)
5174
skipped_elements = d*(p-1) + min(r,p-1)
5275
lastind = d*p + min(r,p)
5376
elements_on_proc = lastind - skipped_elements
54-
Iterators.take(Iterators.drop(itzip, skipped_elements), elements_on_proc)
77+
ZipSplit(itzip, skipped_elements, elements_on_proc)
5578
end
5679

5780
_split_iterators(iterators, np, p) = (zipsplit(iterators, np, p),)

src/productsplit.jl

+15
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,21 @@ Base.lastindex(ps::AbstractConstrainedProduct) = lastindexglobal(ps) - firstinde
217217
firstindexglobal(ps::AbstractConstrainedProduct) = ProductSection(ps).firstind
218218
lastindexglobal(ps::AbstractConstrainedProduct) = ProductSection(ps).lastind
219219

220+
# SplittablesBase interface
221+
function SplittablesBase.halve(ps::AbstractConstrainedProduct)
222+
iter = getiterators(ps)
223+
firstind = firstindexglobal(ps)
224+
lastind = lastindexglobal(ps)
225+
nleft = length(ps) ÷ 2
226+
firstindleft = firstind
227+
lastindleft = firstind + nleft - 1
228+
firstindright = lastindleft + 1
229+
lastindright = lastind
230+
tl = togglelevels(ps)
231+
ProductSection(iter, tl, firstindleft, lastindleft),
232+
ProductSection(iter, tl, firstindright, lastindright)
233+
end
234+
220235
"""
221236
childindex(ps::AbstractConstrainedProduct, ind)
222237

test/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
55
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
66
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
77
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
8+
SplittablesBase = "171d559e-b47b-412a-8079-5efa626c420e"
89
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
910

1011
[compat]

test/misctests_singleprocess.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@ ProductSplit, SegmentedOrderedBinaryTree
99
import ParallelUtilities.ClusterQueryUtils: chooseworkers
1010

1111
@testset "Project quality" begin
12-
Aqua.test_all(ParallelUtilities)
12+
if VERSION < v"1.6.0"
13+
Aqua.test_all(ParallelUtilities, ambiguities=false)
14+
else
15+
Aqua.test_all(ParallelUtilities)
16+
end
1317
end
1418

1519
DocMeta.setdocmeta!(ParallelUtilities, :DocTestSetup, :(using ParallelUtilities); recursive=true)

test/productsplit.jl

+24-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
using Distributed
22
using Test
33
using ParallelUtilities
4-
import ParallelUtilities: ProductSplit, ProductSection,
4+
import ParallelUtilities: ProductSplit, ProductSection, ZipSplit, zipsplit,
55
minimumelement, maximumelement, extremaelement, nelements, dropleading, indexinproduct,
66
extremadims, localindex, extrema_commonlastdim, whichproc, procrange_recast, whichproc_localindex,
77
getiterators, _niterators
8+
using SplittablesBase
89

910
macro testsetwithinfo(str, ex)
1011
quote
@@ -423,6 +424,17 @@ end
423424
@test nelements(ps, dims = 3) == 1
424425
end
425426

427+
@testset "SplittablesBase" begin
428+
for iters in [(1:4, 1:3), (1:4, 1:4)]
429+
for ps = Any[ProductSplit(iters, 3, 2), ProductSection(iters, 3:8)]
430+
l, r = SplittablesBase.halve(ps)
431+
lc, rc = SplittablesBase.halve(collect(ps))
432+
@test collect(l) == lc
433+
@test collect(r) == rc
434+
end
435+
end
436+
end
437+
426438
@test ParallelUtilities._checknorollover((), (), ())
427439
end;
428440

@@ -453,3 +465,14 @@ end;
453465
@test a <= b
454466
end
455467
end;
468+
469+
@testset "ZipSplit" begin
470+
@testset "SplittablesBase" begin
471+
for ps in [zipsplit((1:4, 1:4), 3, 2), zipsplit((1:5, 1:5), 3, 2)]
472+
l, r = SplittablesBase.halve(ps)
473+
lc, rc = SplittablesBase.halve(collect(ps))
474+
@test collect(l) == lc
475+
@test collect(r) == rc
476+
end
477+
end
478+
end

0 commit comments

Comments
 (0)