28
28
from inspect import isgenerator
29
29
30
30
# require at least one iterable to make this work seamlessly with curry.
31
- def scanl (proc , init , iterable0 , * iterables ):
31
+ def scanl (proc , init , iterable0 , * iterables , longest = False , fillvalue = None ):
32
32
"""Scan (accumulate), optionally with multiple input iterables.
33
33
34
34
Similar to ``itertools.accumulate``. If the inputs are generators, this is
35
35
essentially a lazy ``foldl`` that yields the intermediate result at each step.
36
36
Hence, useful for partially folding infinite sequences.
37
37
38
- At least one iterable (``iterable0``) is required. More are optional.
39
-
40
- Terminates when the shortest input runs out.
41
-
42
38
Initial value is mandatory; there is no sane default for the case with
43
39
multiple inputs.
44
40
41
+ At least one iterable (``iterable0``) is required. More are optional.
42
+
43
+ By default, terminate when the shortest input runs out. To terminate on
44
+ longest input, use ``longest=True`` and optionally provide a ``fillvalue``.
45
+
45
46
Returns a generator, which (roughly, in pseudocode)::
46
47
48
+ z = partial(zip_longest, fillvalue=fillvalue) if longest else zip
47
49
acc = init
48
- for elts in zip (iterable0, *iterables):
50
+ for elts in z (iterable0, *iterables):
49
51
yield proc(*elts, acc) # if this was legal syntax
50
52
"""
51
53
iterables = (iterable0 ,) + iterables
52
- def heads (its ):
53
- hs = []
54
- for it in its :
55
- try :
56
- h = next (it )
57
- except StopIteration : # shortest sequence ran out
54
+ if not longest : # terminate on shortest input
55
+ def heads (its ):
56
+ hs = []
57
+ for it in its :
58
+ try :
59
+ h = next (it )
60
+ except StopIteration : # shortest sequence ran out
61
+ return StopIteration
62
+ hs .append (h )
63
+ return tuple (hs )
64
+ else : # terminate on longest input
65
+ def heads (its ):
66
+ hs = []
67
+ nempty = 0
68
+ for it in its :
69
+ try :
70
+ h = next (it )
71
+ except StopIteration : # this sequence has run out
72
+ h = fillvalue
73
+ nempty += 1 # may legitimately contain None so must count
74
+ hs .append (h )
75
+ if nempty == len (its ):
58
76
return StopIteration
59
- hs .append (h )
60
- return tuple (hs )
77
+ return tuple (hs )
61
78
iters = tuple (iter (x ) for x in iterables )
62
79
acc = init
63
80
while True :
@@ -67,9 +84,10 @@ def heads(its):
67
84
break
68
85
acc = proc (* (hs + (acc ,)))
69
86
70
- def scanr (proc , init , sequence0 , * sequences ):
87
+ def scanr (proc , init , sequence0 , * sequences , longest = False , fillvalue = None ):
71
88
"""Like scanl, but scan from the right (walk each sequence backwards)."""
72
- return scanl (proc , init , reversed (sequence0 ), * (reversed (s ) for s in sequences ))
89
+ return scanl (proc , init , reversed (sequence0 ), * (reversed (s ) for s in sequences ),
90
+ longest = longest , fillvalue = fillvalue )
73
91
74
92
def scanl1 (proc , iterable , init = None ):
75
93
"""scanl for a single iterable, with optional init.
@@ -94,25 +112,28 @@ def scanr1(proc, sequence, init=None):
94
112
"""
95
113
return scanl1 (proc , reversed (sequence ), init )
96
114
97
- def foldl (proc , init , iterable0 , * iterables ):
115
+ def foldl (proc , init , iterable0 , * iterables , longest = False , fillvalue = None ):
98
116
"""Racket-like foldl that supports multiple input iterables.
99
117
100
- At least one iterable (``iterable0``) is required. More are optional.
101
-
102
- Terminates when the shortest input runs out.
103
-
104
118
Initial value is mandatory; there is no sane default for the case with
105
119
multiple inputs.
106
120
121
+ At least one iterable (``iterable0``) is required. More are optional.
122
+
123
+ By default, terminate when the shortest input runs out. To terminate on
124
+ longest input, use ``longest=True`` and optionally provide a ``fillvalue``.
125
+
107
126
Note order: ``proc(elt, acc)``, which is the opposite order of arguments
108
127
compared to ``functools.reduce``. General case ``proc(e1, ..., en, acc)``.
109
128
"""
110
- return last (scanl (proc , init , iterable0 , * iterables ))
129
+ return last (scanl (proc , init , iterable0 , * iterables ,
130
+ longest = longest , fillvalue = fillvalue ))
111
131
112
- def foldr (proc , init , sequence0 , * sequences ):
132
+ def foldr (proc , init , sequence0 , * sequences , longest = False , fillvalue = None ):
113
133
"""Like foldl, but fold from the right (walk each sequence backwards)."""
114
134
# Reverse, then left-fold gives us a linear process.
115
- return foldl (proc , init , reversed (sequence0 ), * (reversed (s ) for s in sequences ))
135
+ return foldl (proc , init , reversed (sequence0 ), * (reversed (s ) for s in sequences ),
136
+ longest = longest , fillvalue = fillvalue )
116
137
117
138
def reducel (proc , iterable , init = None ):
118
139
"""Foldl for a single iterable, with optional init.
@@ -461,7 +482,7 @@ def test():
461
482
from operator import add , mul , itemgetter
462
483
from functools import partial
463
484
from unpythonic .fun import curry , composer , composerc , composel , to1st , rotate , identity
464
- from unpythonic .llist import cons , nil , ll
485
+ from unpythonic .llist import cons , nil , ll , lreverse
465
486
466
487
# scan/accumulate: lazy fold that yields intermediate results.
467
488
assert tuple (scanl (add , 0 , range (1 , 5 ))) == (0 , 1 , 3 , 6 , 10 )
@@ -544,6 +565,14 @@ def mymap_one2(f, sequence):
544
565
myadd = lambda x , y : x + y # can't inspect signature of builtin add
545
566
assert curry (mymap , myadd , ll (1 , 2 , 3 ), ll (2 , 4 , 6 )) == ll (3 , 6 , 9 )
546
567
568
+ # map_longest. foldr would walk the sequences from the right; use foldl.
569
+ mymap_longestrev = lambda f : curry (foldl , composerc (cons , f ), nil , longest = True )
570
+ mymap_longest = composerc (lreverse , mymap_longestrev )
571
+ def noneadd (a , b ):
572
+ if all (x is not None for x in (a , b )):
573
+ return a + b
574
+ assert curry (mymap_longest , noneadd , ll (1 , 2 , 3 ), ll (2 , 4 )) == ll (3 , 6 , None )
575
+
547
576
reverse_one = curry (foldl , cons , nil )
548
577
assert reverse_one (ll (1 , 2 , 3 )) == ll (3 , 2 , 1 )
549
578
@@ -561,6 +590,10 @@ def mymap_one2(f, sequence):
561
590
assert mysum (append_two (a , b )) == 10
562
591
assert myprod (b ) == 12
563
592
593
+ packtwo = lambda a , b : ll (a , b ) # using a tuple return value here would confuse curry.
594
+ assert foldl (composerc (cons , packtwo ), nil , (1 , 2 , 3 ), (4 , 5 ), longest = True ) == \
595
+ ll (ll (3 , None ), ll (2 , 5 ), ll (1 , 4 ))
596
+
564
597
def msqrt (x ): # multivalued sqrt
565
598
if x == 0. :
566
599
return (0. ,)
0 commit comments