@@ -37,21 +37,44 @@ getiterators(h::Hold) = getiterators(h.iterators)
37
37
38
38
Base. length (h:: Hold ) = length (h. iterators)
39
39
40
- check_knownsize (iterators:: Tuple ) = _check_knownsize (first (iterators)) & check_knownsize (Base. tail (iterators))
41
- check_knownsize (:: Tuple{} ) = true
42
- function _check_knownsize (iterator)
40
+ function check_knownsize (iterator)
43
41
itsz = Base. IteratorSize (iterator)
44
42
itsz isa Base. HasLength || itsz isa Base. HasShape
45
43
end
46
44
47
- function zipsplit (iterators:: Tuple , np:: Integer , p:: Integer )
48
- check_knownsize (iterators)
49
- itzip = zip (iterators... )
45
+ struct ZipSplit{Z, I}
46
+ z :: Z
47
+ it :: I
48
+ skip :: Int
49
+ N :: Int
50
+ end
51
+
52
+ # This constructor differs from zipsplit, as it uses skipped and retained elements
53
+ # and not p and np. This type is added to increase compatibility with SplittablesBase
54
+ function ZipSplit (itzip, skipped_elements:: Integer , elements_on_proc:: Integer )
55
+ it = Iterators. take (Iterators. drop (itzip, skipped_elements), elements_on_proc)
56
+ ZipSplit {typeof(itzip), typeof(it)} (itzip, it, skipped_elements, elements_on_proc)
57
+ end
58
+
59
+ Base. length (zs:: ZipSplit ) = length (zs. it)
60
+ Base. eltype (zs:: ZipSplit ) = eltype (zs. it)
61
+ Base. iterate (z:: ZipSplit , i... ) = iterate (takedrop (z), i... )
62
+ takedrop (zs:: ZipSplit ) = zs. it
63
+
64
+ function SplittablesBase. halve (zs:: ZipSplit )
65
+ nleft = zs. N ÷ 2
66
+ ZipSplit (zs. z, zs. skip, nleft), ZipSplit (zs. z, zs. skip + nleft, zs. N - nleft)
67
+ end
68
+
69
+ zipsplit (iterators:: Tuple , np:: Integer , p:: Integer ) = zipsplit (zip (iterators... ), np, p)
70
+
71
+ function zipsplit (itzip:: Iterators.Zip , np:: Integer , p:: Integer )
72
+ check_knownsize (itzip)
50
73
d,r = divrem (length (itzip), np)
51
74
skipped_elements = d* (p- 1 ) + min (r,p- 1 )
52
75
lastind = d* p + min (r,p)
53
76
elements_on_proc = lastind - skipped_elements
54
- Iterators . take (Iterators . drop ( itzip, skipped_elements) , elements_on_proc)
77
+ ZipSplit ( itzip, skipped_elements, elements_on_proc)
55
78
end
56
79
57
80
_split_iterators (iterators, np, p) = (zipsplit (iterators, np, p),)
0 commit comments