Skip to content

Commit 413b017

Browse files
committed
tidy up
1 parent 9af7a64 commit 413b017

File tree

1 file changed

+31
-65
lines changed

1 file changed

+31
-65
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 31 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ function rrule(
465465
y = first(last(hobbits))
466466
project = ProjectTo(x)
467467
function foldl_pullback_tuple(dy)
468-
trio = accumulate(_reverse1(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
468+
trio = accumulate(reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
469469
ds, da, db = back(dc)
470470
# Don't need to store every `da`, need one for the next iteration + the last.
471471
end
@@ -501,78 +501,43 @@ end
501501

502502
# The implementation was originally for both tuples and arrays, although using accumulate
503503
# to carry intermediate results along creates arrays of tuples which could be avoided.
504-
# Using a loop can be a few times faster, this should be replaced.
505-
# Note also that it does not return a gradient for `init`.
504+
# Using a loop can be a few times faster, this should be replaced:
505+
# https://github.com/FluxML/Zygote.jl/issues/644#issuecomment-628762305
506+
507+
# Note also that it does not return a gradient for `init`, now marked `@not_implemented`.
506508

507509
function rrule(
508-
config::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::G, init, x::Union{AbstractArray, Tuple};
510+
config::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::G, init, x::Union{AbstractArray, Tuple};
509511
) where {G}
510-
list, start = if init === _INIT
511-
_drop1(x), first(x)
512+
start, list = if init === Base._InitialValue()
513+
Iterators.peel(x)
512514
else
513515
# Case with init keyword is simpler to understand first!
514-
_reshape1(x, :), init # (vec is for Julia 1.0, accumulate is fussy)
516+
init, x
515517
end
516-
hobbits = accumulate(list; init=(start, nothing)) do (a,_), b
517-
# Here `a` is what we would normally carry forward, and `_` ignores
518-
# the previous iteration's pullback function (needed later),
519-
# while `b` is the fresh input from `list` as usual.
520-
c, back = rrule_via_ad(config, op, a, b) # LHS is just documentation here!
521-
# We don't really need to store every `c`, last one is `foldl` output.
522-
# (The name, BTW, is because "there and back again" is the subtitle of Tolkien's book.)
518+
hobbits = accumulate(list; init=(start, nothing)) do (a, _), b
519+
c, back = rrule_via_ad(config, op, a, b)
523520
end
524521
y = first(last(hobbits))
525522
axe = axes(x)
526523
project = ProjectTo(x)
527524
function unfoldl(dy)
528-
trio = accumulate(_reverse1(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
525+
trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
529526
ds, da, db = back(dc)
530-
# Don't need to store every `da`, need one for the next iteration + maybe last
531527
end
532528
dop = sum(first, trio)
533-
dx = map(last, _reverse1(trio))
534-
if init === _INIT
535-
# `hobbits` is one short
529+
dx = map(last, Iterators.reverse(trio))
530+
if init === Base._InitialValue() # `hobbits` is one short
536531
dx = _vcat1(trio[end][2], dx)
537532
end
538533
d_init = @not_implemented "gradient for foldl does not at present include init, sorry"
539-
return (NoTangent(), NoTangent(), dop, d_init, project(_reshape1(dx, axe)))
534+
return (NoTangent(), NoTangent(), dop, d_init, project(reshape(dx, axe)))
540535
end
541536
return y, unfoldl
542537
end
543538

544-
545-
#####
546-
##### Iterator-or-Tuple functions
547-
#####
548-
549-
# This zoo of underscore functions helps `foldl` & `accumulate` handle both tuples and arrays,
550-
# and also provides some alternatives for versions of Julia where iterators weren't supported.
551-
# Inspired by `Base._reverse`, used in defn of `foldr`.
552-
553-
# To support 2nd derivatives, some may need their own gradient rules. And _drop1 should perhaps
554-
# be replaced by _peel1 like Iterators.peel
555-
556-
_reverse1(x) = Iterators.reverse(x)
557-
_drop1(x) = Iterators.drop(x, 1)
558-
_zip2(x, y) = zip(x, y) # for `accumulate`, below
559-
560-
_reverse1(x::Tuple) = reverse(x)
561-
_drop1(x::Tuple) = Base.tail(x)
562-
_zip2(x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,N}}) where N = ntuple(i -> (x[i],y[i]), N)
563-
564-
const _INIT = Base._InitialValue()
565-
566539
_vcat1(x, ys::AbstractVector) = vcat(x, ys)
567540
_vcat1(x::AbstractArray, ys::AbstractVector) = vcat([x], ys)
568-
_vcat1(x, ys::Tuple) = (x, ys...)
569-
570-
_reshape1(x::AbstractArray, axe) = reshape(x, axe)
571-
_reshape1(x::Tuple, axe) = x
572-
573-
_no_tuple_tangent(dx::Tangent) = ChainRulesCore.backing(dx)
574-
_no_tuple_tangent(dx) = dx
575-
576541

577542
#####
578543
##### `accumulate`
@@ -584,13 +549,18 @@ _no_tuple_tangent(dx) = dx
584549
# Move it down to: `_accumulate!(op, B, A::AbstractVector, dims::Nothing, init::Nothing)`
585550

586551
function rrule(
587-
config::RuleConfig{>:HasReverseMode}, ::typeof(Base._accumulate!), op::G, y, x::AbstractVector, dims::Nothing, init,
552+
config::RuleConfig{>:HasReverseMode},
553+
::typeof(Base._accumulate!),
554+
op::G, y::AbstractVector,
555+
x::AbstractVector,
556+
dims::Nothing,
557+
init,
588558
) where {G}
589559

590-
list, start = if init === nothing
591-
_drop1(x), first(x)
560+
start, list = if init === nothing
561+
Iterators.peel(x)
592562
else
593-
x, something(init)
563+
something(init), x
594564
end
595565
hobbits = accumulate(list; init = (start, nothing)) do (a, _), b
596566
c, back = rrule_via_ad(config, op, a, b)
@@ -607,28 +577,24 @@ function rrule(
607577
axe = axes(x)
608578
project = ProjectTo(x)
609579
function decumulate(dy)
610-
dy_plain = _no_tuple_tangent(unthunk(dy))
611-
rev_list = if init === nothing
612-
# Here we rely on `zip` to stop early. Being explicit with _reverse1(_drop1(...))
613-
# gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{"
614-
_zip2(_reverse1(hobbits), _reverse1(dy_plain))
615-
else
616-
_zip2(_reverse1(hobbits), _reverse1(dy_plain))
617-
end
580+
dy_plain = unthunk(dy)
581+
rev_list = zip(Iterators.reverse(hobbits), Iterators.reverse(dy_plain))
582+
# Here we rely on `zip` to stop early when init === nothing. Being explicit with Iterators.reverse(Iterators.drop(..., 1))
583+
# gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{"
618584
trio = accumulate(rev_list; init=(0, ZeroTangent(), 0)) do (_, dc, _), ((_, back), dz)
619585
ds, da, db = back(dc + dz)
620586
# Don't need to store every 'da', but need for next iteration, and the last one.
621587
end
622588
dop = sum(first, trio)
623-
dx = map(last, _reverse1(trio))
589+
dx = map(last, Iterators.reverse(trio))
624590
if init == nothing
625591
# `hobbits` is one short, and the first one is weird
626592
dx = _vcat1(trio[end][2] + dy_plain[1], dx)
627593
end
628594
dy = @not_implemented "no gradient for `B` in `accumulate!(f, B, A)`, the rule intends to support `accumulate` only"
629595
d_init_not = @not_implemented "gradient for accumulate does not at present include init, sorry"
630596
d_init = init === nothing ? NoTangent() : Tangent{typeof(init)}(; value = d_init_not)
631-
return (NoTangent(), dop, dy, project(_reshape1(dx, axe)), NoTangent(), d_init)
597+
return (NoTangent(), dop, dy, project(reshape(dx, axe)), NoTangent(), d_init)
632598
end
633-
return _reshape1(y, axe), decumulate
599+
return reshape(y, axe), decumulate
634600
end

0 commit comments

Comments
 (0)