From f901e02790dc7c24a02d2af771f74b67cb700615 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Mon, 7 Mar 2022 00:35:31 -0500
Subject: [PATCH 1/2] add lookahead optimiser

---
 docs/src/api.md   |  1 +
 src/Optimisers.jl |  2 +-
 src/rules.jl      | 52 ++++++++++++++++++++++++++++++++++++++++++++++++
 test/rules.jl     |  3 +++
 4 files changed, 57 insertions(+), 1 deletion(-)

diff --git a/docs/src/api.md b/docs/src/api.md
index edd8be32..f5e9fcb0 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -25,6 +25,7 @@ Optimisers.ClipGrad
 Optimisers.ClipNorm
 Optimisers.WeightDecay
 Optimisers.OptimiserChain
+Optimisers.Lookahead
 ```
 
 ## Model Interface
diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index 417b90d4..6f3063da 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -11,7 +11,7 @@ export destructure, total, total2
 include("rules.jl")
 export Descent, ADAM, Momentum, Nesterov, RMSProp,
        ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM, OADAM, AdaBelief,
-       WeightDecay, ClipGrad, ClipNorm, OptimiserChain
+       WeightDecay, ClipGrad, ClipNorm, OptimiserChain, Lookahead
 
 """
     Optimisers.apply!(rule::RuleType, state, parameters, gradient) -> (state, gradient)
diff --git a/src/rules.jl b/src/rules.jl
index 65392a0f..eb182aa0 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -555,3 +555,55 @@ function Base.show(io::IO, c::OptimiserChain)
   join(io, c.opts, ", ")
   print(io, ")")
 end
+
+"""
+    Lookahead(α = 0.5, k = 10, opt = Momentum())
+
+
+The [Lookahead](https://arxiv.org/abs/1907.08610) optimiser
+keeps a "slow" copy of the parameters, initially `y = copy(x)`.
+Most steps update only the "fast" parameters `x` as usual,
+using the given `opt`, but every `k`th step updates the slow
+parameters, and then resets the fast ones to match:
+
+    @. y = α * x + (1-α) * y
+    @. x = y
+
+# Parameters
+- Slow weight step size (`α`): Proportion of the change to fast weights which is kept.
+- Synchronization period (`k`): Number of fast steps between slow steps.
+- Optimiser (`opt`): Used at each fast step.
+"""
+struct Lookahead{T<:Real, O}
+    alpha::T
+    steps::Int
+    inner::O
+end
+Lookahead(α::Real = 5f-1, k::Int = 10, opt = Momentum()) = Lookahead{typeof(α), typeof(opt)}(α, k, opt)
+
+init(o::Lookahead, x::AbstractArray) = (copy(x), 1, init(o.inner, x))
+
+update!(ℓ::Leaf{<:Lookahead}, x, ::Zero, ::Zero...) = ℓ, x
+function update!(ℓ::Leaf{<:Lookahead}, x, x̄s...)
+    α = ℓ.rule.alpha
+    k = ℓ.rule.steps
+    y, n, instate = ℓ.state
+
+    is′, x̄′ = apply!(ℓ.rule.inner, instate, x, base.(x̄s)...)
+
+    if n % k != 0
+        return Leaf(ℓ.rule, (y, n + 1, is′)), subtract!(x, x̄′)
+    else
+        @.. y = α * (x - x̄′) + (1 - α) * y
+        @.. x = y
+        return Leaf(ℓ.rule, (y, n + 1, is′)), x
+    end
+end
+
+function Base.show(io::IO, o::Lookahead)
+    print(io, "Lookahead(")
+    show(io, o.alpha)
+    print(io, ", ", o.steps, ", ")
+    show(io, o.inner)
+    print(io, ")")
+end
diff --git a/test/rules.jl b/test/rules.jl
index ffb4ca65..d5aaa97e 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -14,10 +14,13 @@ RULES = [
     OptimiserChain(ClipNorm(), ADAM(0.001)),
     OptimiserChain(ClipGrad(0.5), Momentum()),
     OptimiserChain(WeightDecay(), OADAM(), ClipGrad(1)),
+    # Lookahead
+    Lookahead(), Lookahead(0.5, 5, ADAMW(0.001))
 ]
 
 name(o) = typeof(o).name.name # just for printing testset headings
 name(o::OptimiserChain) = join(name.(o.opts), " → ")
+name(o::Lookahead) = string("Lookahead(", name(o.inner), ")")
 
 LOG = Dict() # for debugging these testsets, this makes it easy to plot each optimiser's loss
 

From ebe294cdf2bf95566550c37ba0f6ba3408d117ed Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Mon, 7 Mar 2022 10:04:55 -0500
Subject: [PATCH 2/2] tidy, add example

---
 src/rules.jl  | 54 ++++++++++++++++++++++++++++++++++++-------------------
 test/rules.jl |  2 +-
 2 files changed, 36 insertions(+), 20 deletions(-)

diff --git a/src/rules.jl b/src/rules.jl
index eb182aa0..c929251d 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -557,11 +557,10 @@ function Base.show(io::IO, c::OptimiserChain)
 end
 
 """
-    Lookahead(α = 0.5, k = 10, opt = Momentum())
-
+    Lookahead(opt = Momentum(), k = 10, α = 0.5)
 
 The [Lookahead](https://arxiv.org/abs/1907.08610) optimiser
-keeps a "slow" copy of the parameters, initially `y = copy(x)`.
+keeps a "slow" copy of the parameters, `y`.
 Most steps update only the "fast" parameters `x` as usual,
 using the given `opt`, but every `k`th step updates the slow
 parameters, and then resets the fast ones to match:
@@ -570,40 +569,57 @@ parameters, and then resets the fast ones to match:
     @. x = y
 
 # Parameters
-- Slow weight step size (`α`): Proportion of the change to fast weights which is kept.
-- Synchronization period (`k`): Number of fast steps between slow steps.
 - Optimiser (`opt`): Used at each fast step.
+- Synchronization period (`k`): Number of fast steps between slow steps.
+- Slow weight step size (`α`): Proportion of the change to fast weights which is kept.
+
+# Example
+```jldoctest
+julia> x = [10.0]; o = Lookahead(Descent(0.5), 4); s = Optimisers.setup(o, x);
+
+julia> for t in 1:8
+         s, x = Optimisers.update!(s, x, [1])
+         y = s.state[3]
+         @show x, y
+       end
+(x, y) = ([9.5], [10.0])
+(x, y) = ([9.0], [10.0])
+(x, y) = ([8.5], [10.0])
+(x, y) = ([9.0], [9.0])
+(x, y) = ([8.5], [9.0])
+(x, y) = ([8.0], [9.0])
+(x, y) = ([7.5], [9.0])
+(x, y) = ([8.0], [8.0])
+```
 """
-struct Lookahead{T<:Real, O}
-    alpha::T
-    steps::Int
+struct Lookahead{O, T<:Real}
     inner::O
+    k::Int
+    alpha::T
 end
-Lookahead(α::Real = 5f-1, k::Int = 10, opt = Momentum()) = Lookahead{typeof(α), typeof(opt)}(α, k, opt)
+Lookahead(opt = Momentum(), k::Int = 10, α = 5f-1) = Lookahead{typeof(opt),typeof(α)}(opt, k, α)
 
-init(o::Lookahead, x::AbstractArray) = (copy(x), 1, init(o.inner, x))
+init(o::Lookahead, x::AbstractArray) = (init(o.inner, x), 1, copy(x))
 
 update!(ℓ::Leaf{<:Lookahead}, x, ::Zero, ::Zero...) = ℓ, x
 function update!(ℓ::Leaf{<:Lookahead}, x, x̄s...)
-    α = ℓ.rule.alpha
-    k = ℓ.rule.steps
-    y, n, instate = ℓ.state
-
+    instate, n, y = ℓ.state
     is′, x̄′ = apply!(ℓ.rule.inner, instate, x, base.(x̄s)...)
 
-    if n % k != 0
-        return Leaf(ℓ.rule, (y, n + 1, is′)), subtract!(x, x̄′)
+    if n % (ℓ.rule.k) != 0
+        return Leaf(ℓ.rule, (is′, n+1, y)), subtract!(x, x̄′)
     else
+        α = ℓ.rule.alpha
         @.. y = α * (x - x̄′) + (1 - α) * y
         @.. x = y
-        return Leaf(ℓ.rule, (y, n + 1, is′)), x
+        return Leaf(ℓ.rule, (is′, n+1, y)), x
     end
 end
 
 function Base.show(io::IO, o::Lookahead)
     print(io, "Lookahead(")
-    show(io, o.alpha)
-    print(io, ", ", o.steps, ", ")
     show(io, o.inner)
+    print(io, ", ", o.k, ", ")
+    show(io, o.alpha)
     print(io, ")")
 end
diff --git a/test/rules.jl b/test/rules.jl
index d5aaa97e..cf37f596 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -15,7 +15,7 @@ RULES = [
     OptimiserChain(ClipGrad(0.5), Momentum()),
     OptimiserChain(WeightDecay(), OADAM(), ClipGrad(1)),
     # Lookahead
-    Lookahead(), Lookahead(0.5, 5, ADAMW(0.001))
+    Lookahead(), Lookahead(ADAMW(0.003), 5, 0.7)
 ]
 
 name(o) = typeof(o).name.name # just for printing testset headings
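Not part of the patches above, just a sketch for anyone trying the rule out locally: with PATCH 2/2's final argument order `Lookahead(opt, k, α)`, the rule composes with the usual `Optimisers.setup` / `Optimisers.update!` API on structured models as well as bare arrays. The tiny NamedTuple "model", the `Descent(0.1)` rate, and the constant gradient below are invented for illustration; the numbers follow from the update rule as written.

```julia
julia> using Optimisers

julia> model = (w = [1.0, 2.0],);  # any nested structure of arrays works with setup

julia> state = Optimisers.setup(Lookahead(Descent(0.1), 4, 0.5), model);

julia> for _ in 1:4  # three fast steps, then the 4th also takes a slow step
         state, model = Optimisers.update!(state, model, (w = [1.0, 1.0],))
       end

julia> model.w  # halfway between the start [1.0, 2.0] and the fast weights [0.6, 1.6]
2-element Vector{Float64}:
 0.8
 1.8
```

Keeping the slow copy `y` inside each `Leaf`'s state, rather than in the rule struct, is what makes this work: one `Lookahead` instance can be shared across every array in the model, while each leaf tracks its own slow weights and step count.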