@@ -465,7 +465,7 @@ function rrule(
465
465
y = first (last (hobbits))
466
466
project = ProjectTo (x)
467
467
function foldl_pullback_tuple (dy)
468
- trio = accumulate (_reverse1 (hobbits); init= (0 , dy, 0 )) do (_, dc, _), (_, back)
468
+ trio = accumulate (reverse (hobbits); init= (0 , dy, 0 )) do (_, dc, _), (_, back)
469
469
ds, da, db = back (dc)
470
470
# Don't need to store every `da`, need one for the next iteration + the last.
471
471
end
@@ -501,78 +501,43 @@ end
501
501
502
502
# The implementation was originally for both tuples and arrays, although using accumulate
503
503
# to carry intermediate results along creates arrays of tuples which could be avoided.
504
- # Using a loop can be a few times faster, this should be replaced.
505
- # Note also that it does not return a gradient for `init`.
504
+ # Using a loop can be a few times faster, this should be replaced:
505
+ # https://github.com/FluxML/Zygote.jl/issues/644#issuecomment-628762305
506
+
507
+ # Note also that it does not return a gradient for `init`, now marked `@not_implemented`.
506
508
507
509
# Reverse-mode rule for `Base.mapfoldl_impl` with `identity` as the mapped function,
# accepting arrays and tuples. The forward pass records, for every step of the fold,
# the intermediate value together with the pullback of `op` obtained via `rrule_via_ad`;
# the pullback replays those records from last to first.
# Returns `(y, unfoldl)`. The tangent for `init` is marked `@not_implemented`.
function rrule(
    config::RuleConfig{>:HasReverseMode},
    ::typeof(Base.mapfoldl_impl),
    ::typeof(identity),
    op::G,
    init,
    x::Union{AbstractArray,Tuple};
) where {G}
    # Without an explicit `init`, the first element of `x` seeds the fold and only the
    # remaining elements are stepped through.
    no_init = init === Base._InitialValue()
    seed, rest = no_init ? Iterators.peel(x) : (init, x)

    # Each entry is `(value so far, pullback of the step that produced it)`.
    steps = accumulate(rest; init=(seed, nothing)) do (acc, _), item
        return rrule_via_ad(config, op, acc, item)
    end
    y = first(last(steps))
    axe = axes(x)
    project = ProjectTo(x)

    function unfoldl(dy)
        # The middle slot carries the cotangent of the running value back to the
        # previous step; the outer slots hold the cotangents of `op` and of the element.
        grads = accumulate(Iterators.reverse(steps); init=(0, dy, 0)) do (_, dcarry, _), (_, pb)
            return pb(dcarry)
        end
        dop = sum(first, grads)
        dx = map(last, Iterators.reverse(grads))
        if no_init
            # `steps` is one short: the seed element's cotangent is the final carry.
            dx = _vcat1(grads[end][2], dx)
        end
        d_init = @not_implemented " gradient for foldl does not at present include init, sorry"
        return (NoTangent(), NoTangent(), dop, d_init, project(reshape(dx, axe)))
    end

    return y, unfoldl
end
543
538
544
-
545
- # ####
546
- # #### Iterator-or-Tuple functions
547
- # ####
548
-
549
- # This zoo of underscore functions helps `foldl` & `accumulate` handle both tuples and arrays,
550
- # and also provides some alternatives for versions of Julia where iterators weren't supported.
551
- # Inspired by `Base._reverse`, used in defn of `foldr`.
552
-
553
- # To support 2nd derivatives, some may need their own gradient rules. And _drop1 should perhaps
554
- # be replaced by _peel1 like Iterators.peel
555
-
556
- _reverse1 (x) = Iterators. reverse (x)
557
- _drop1 (x) = Iterators. drop (x, 1 )
558
- _zip2 (x, y) = zip (x, y) # for `accumulate`, below
559
-
560
- _reverse1 (x:: Tuple ) = reverse (x)
561
- _drop1 (x:: Tuple ) = Base. tail (x)
562
- _zip2 (x:: Tuple{Vararg{Any,N}} , y:: Tuple{Vararg{Any,N}} ) where N = ntuple (i -> (x[i],y[i]), N)
563
-
564
- const _INIT = Base. _InitialValue ()
565
-
566
539
# Prepend the single item `x` to the vector `ys`.
function _vcat1(x, ys::AbstractVector)
    return vcat(x, ys)
end
567
540
# An array-valued `x` counts as ONE item: wrap it in a vector first, so that `vcat`
# prepends the array itself instead of splicing its elements into `ys`.
function _vcat1(x::AbstractArray, ys::AbstractVector)
    return vcat([x], ys)
end
568
- _vcat1 (x, ys:: Tuple ) = (x, ys... )
569
-
570
- _reshape1 (x:: AbstractArray , axe) = reshape (x, axe)
571
- _reshape1 (x:: Tuple , axe) = x
572
-
573
- _no_tuple_tangent (dx:: Tangent ) = ChainRulesCore. backing (dx)
574
- _no_tuple_tangent (dx) = dx
575
-
576
541
577
542
# ####
578
543
# #### `accumulate`
@@ -584,13 +549,18 @@ _no_tuple_tangent(dx) = dx
584
549
# Move it down to: `_accumulate!(op, B, A::AbstractVector, dims::Nothing, init::Nothing)`
585
550
586
551
function rrule (
587
- config:: RuleConfig{>:HasReverseMode} , :: typeof (Base. _accumulate!), op:: G , y, x:: AbstractVector , dims:: Nothing , init,
552
+ config:: RuleConfig{>:HasReverseMode} ,
553
+ :: typeof (Base. _accumulate!),
554
+ op:: G , y:: AbstractVector ,
555
+ x:: AbstractVector ,
556
+ dims:: Nothing ,
557
+ init,
588
558
) where {G}
589
559
590
- list, start = if init === nothing
591
- _drop1 (x), first (x)
560
+ start, list = if init === nothing
561
+ Iterators . peel (x)
592
562
else
593
- x, something (init)
563
+ something (init), x
594
564
end
595
565
hobbits = accumulate (list; init = (start, nothing )) do (a, _), b
596
566
c, back = rrule_via_ad (config, op, a, b)
@@ -607,28 +577,24 @@ function rrule(
607
577
axe = axes (x)
608
578
project = ProjectTo (x)
609
579
# Pullback of the `accumulate` rule. Takes the cotangent `dy` of the accumulated output
# and returns cotangents for `(_accumulate!, op, y, x, dims, init)`. The entries for the
# output buffer and for `init` are marked `@not_implemented`.
function decumulate(dy)
    dy_plain = unthunk(dy)
    # Pair every recorded step with the cotangent of the output it produced, last step
    # first. When `init === nothing`, `hobbits` is one shorter than `dy_plain`, and we
    # rely on `zip` to stop early; the unpaired first cotangent is added back below.
    # Being explicit with `Iterators.reverse(Iterators.drop(..., 1))` instead gets
    # "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{"
    rev_list = zip(Iterators.reverse(hobbits), Iterators.reverse(dy_plain))
    trio = accumulate(rev_list; init=(0, ZeroTangent(), 0)) do (_, dc, _), ((_, back), dz)
        # Don't need to store every `da`, but need it for the next iteration, and the last one.
        ds, da, db = back(dc + dz)
    end
    dop = sum(first, trio)
    dx = map(last, Iterators.reverse(trio))
    if init === nothing
        # `hobbits` is one short, and the first element is special: it receives both the
        # cotangent carried back from the first step and its own output cotangent.
        dx = _vcat1(trio[end][2] + first(dy_plain), dx)
    end
    # Use a fresh name instead of rebinding the argument `dy` to a different type.
    dy_not = @not_implemented " no gradient for `B` in `accumulate!(f, B, A)`, the rule intends to support `accumulate` only"
    d_init_not = @not_implemented " gradient for accumulate does not at present include init, sorry"
    d_init = init === nothing ? NoTangent() : Tangent{typeof(init)}(; value = d_init_not)
    return (NoTangent(), dop, dy_not, project(reshape(dx, axe)), NoTangent(), d_init)
end
633
- return _reshape1 (y, axe), decumulate
599
+ return reshape (y, axe), decumulate
634
600
end
0 commit comments