Skip to content

Commit b562625

Browse files
committed
parameterize scan and fold so that they can now terminate on the longest input (with optional fillvalue, defaulting to None)
1 parent 241ba11 commit b562625

File tree

2 files changed

+59
-26
lines changed

2 files changed

+59
-26
lines changed

unpythonic/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
See ``dir(unpythonic)`` and submodule docstrings for more.
66
"""
77

8-
__version__ = '0.8.3'
8+
__version__ = '0.8.4'
99

1010
from . import rc
1111

unpythonic/it.py

Lines changed: 58 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -28,36 +28,53 @@
2828
from inspect import isgenerator
2929

3030
# require at least one iterable to make this work seamlessly with curry.
31-
def scanl(proc, init, iterable0, *iterables):
31+
def scanl(proc, init, iterable0, *iterables, longest=False, fillvalue=None):
3232
"""Scan (accumulate), optionally with multiple input iterables.
3333
3434
Similar to ``itertools.accumulate``. If the inputs are generators, this is
3535
essentially a lazy ``foldl`` that yields the intermediate result at each step.
3636
Hence, useful for partially folding infinite sequences.
3737
38-
At least one iterable (``iterable0``) is required. More are optional.
39-
40-
Terminates when the shortest input runs out.
41-
4238
Initial value is mandatory; there is no sane default for the case with
4339
multiple inputs.
4440
41+
At least one iterable (``iterable0``) is required. More are optional.
42+
43+
By default, terminate when the shortest input runs out. To terminate on
44+
longest input, use ``longest=True`` and optionally provide a ``fillvalue``.
45+
4546
Returns a generator, which (roughly, in pseudocode)::
4647
48+
z = partial(zip_longest, fillvalue=fillvalue) if longest else zip
4749
acc = init
48-
for elts in zip(iterable0, *iterables):
50+
for elts in z(iterable0, *iterables):
4951
yield proc(*elts, acc) # if this was legal syntax
5052
"""
5153
iterables = (iterable0,) + iterables
52-
def heads(its):
53-
hs = []
54-
for it in its:
55-
try:
56-
h = next(it)
57-
except StopIteration: # shortest sequence ran out
54+
if not longest: # terminate on shortest input
55+
def heads(its):
56+
hs = []
57+
for it in its:
58+
try:
59+
h = next(it)
60+
except StopIteration: # shortest sequence ran out
61+
return StopIteration
62+
hs.append(h)
63+
return tuple(hs)
64+
else: # terminate on longest input
65+
def heads(its):
66+
hs = []
67+
nempty = 0
68+
for it in its:
69+
try:
70+
h = next(it)
71+
except StopIteration: # this sequence has run out
72+
h = fillvalue
73+
nempty += 1 # may legitimately contain None so must count
74+
hs.append(h)
75+
if nempty == len(its):
5876
return StopIteration
59-
hs.append(h)
60-
return tuple(hs)
77+
return tuple(hs)
6178
iters = tuple(iter(x) for x in iterables)
6279
acc = init
6380
while True:
@@ -67,9 +84,10 @@ def heads(its):
6784
break
6885
acc = proc(*(hs + (acc,)))
6986

70-
def scanr(proc, init, sequence0, *sequences):
87+
def scanr(proc, init, sequence0, *sequences, longest=False, fillvalue=None):
7188
"""Like scanl, but scan from the right (walk each sequence backwards)."""
72-
return scanl(proc, init, reversed(sequence0), *(reversed(s) for s in sequences))
89+
return scanl(proc, init, reversed(sequence0), *(reversed(s) for s in sequences),
90+
longest=longest, fillvalue=fillvalue)
7391

7492
def scanl1(proc, iterable, init=None):
7593
"""scanl for a single iterable, with optional init.
@@ -94,25 +112,28 @@ def scanr1(proc, sequence, init=None):
94112
"""
95113
return scanl1(proc, reversed(sequence), init)
96114

97-
def foldl(proc, init, iterable0, *iterables):
115+
def foldl(proc, init, iterable0, *iterables, longest=False, fillvalue=None):
98116
"""Racket-like foldl that supports multiple input iterables.
99117
100-
At least one iterable (``iterable0``) is required. More are optional.
101-
102-
Terminates when the shortest input runs out.
103-
104118
Initial value is mandatory; there is no sane default for the case with
105119
multiple inputs.
106120
121+
At least one iterable (``iterable0``) is required. More are optional.
122+
123+
By default, terminate when the shortest input runs out. To terminate on
124+
longest input, use ``longest=True`` and optionally provide a ``fillvalue``.
125+
107126
Note order: ``proc(elt, acc)``, which is the opposite order of arguments
108127
compared to ``functools.reduce``. General case ``proc(e1, ..., en, acc)``.
109128
"""
110-
return last(scanl(proc, init, iterable0, *iterables))
129+
return last(scanl(proc, init, iterable0, *iterables,
130+
longest=longest, fillvalue=fillvalue))
111131

112-
def foldr(proc, init, sequence0, *sequences):
132+
def foldr(proc, init, sequence0, *sequences, longest=False, fillvalue=None):
113133
"""Like foldl, but fold from the right (walk each sequence backwards)."""
114134
# Reverse, then left-fold gives us a linear process.
115-
return foldl(proc, init, reversed(sequence0), *(reversed(s) for s in sequences))
135+
return foldl(proc, init, reversed(sequence0), *(reversed(s) for s in sequences),
136+
longest=longest, fillvalue=fillvalue)
116137

117138
def reducel(proc, iterable, init=None):
118139
"""Foldl for a single iterable, with optional init.
@@ -461,7 +482,7 @@ def test():
461482
from operator import add, mul, itemgetter
462483
from functools import partial
463484
from unpythonic.fun import curry, composer, composerc, composel, to1st, rotate, identity
464-
from unpythonic.llist import cons, nil, ll
485+
from unpythonic.llist import cons, nil, ll, lreverse
465486

466487
# scan/accumulate: lazy fold that yields intermediate results.
467488
assert tuple(scanl(add, 0, range(1, 5))) == (0, 1, 3, 6, 10)
@@ -544,6 +565,14 @@ def mymap_one2(f, sequence):
544565
myadd = lambda x, y: x + y # can't inspect signature of builtin add
545566
assert curry(mymap, myadd, ll(1, 2, 3), ll(2, 4, 6)) == ll(3, 6, 9)
546567

568+
# map_longest. foldr would walk the sequences from the right; use foldl.
569+
mymap_longestrev = lambda f: curry(foldl, composerc(cons, f), nil, longest=True)
570+
mymap_longest = composerc(lreverse, mymap_longestrev)
571+
def noneadd(a, b):
572+
if all(x is not None for x in (a, b)):
573+
return a + b
574+
assert curry(mymap_longest, noneadd, ll(1, 2, 3), ll(2, 4)) == ll(3, 6, None)
575+
547576
reverse_one = curry(foldl, cons, nil)
548577
assert reverse_one(ll(1, 2, 3)) == ll(3, 2, 1)
549578

@@ -561,6 +590,10 @@ def mymap_one2(f, sequence):
561590
assert mysum(append_two(a, b)) == 10
562591
assert myprod(b) == 12
563592

593+
packtwo = lambda a, b: ll(a, b) # using a tuple return value here would confuse curry.
594+
assert foldl(composerc(cons, packtwo), nil, (1, 2, 3), (4, 5), longest=True) == \
595+
ll(ll(3, None), ll(2, 5), ll(1, 4))
596+
564597
def msqrt(x): # multivalued sqrt
565598
if x == 0.:
566599
return (0.,)

0 commit comments

Comments
 (0)