diff --git a/docs/src/api.md b/docs/src/api.md
index edd8be32..f5e9fcb0 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -25,6 +25,7 @@ Optimisers.ClipGrad
 Optimisers.ClipNorm
 Optimisers.WeightDecay
 Optimisers.OptimiserChain
+Optimisers.Lookahead
 ```
 
 ## Model Interface
diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index 417b90d4..6f3063da 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -11,7 +11,7 @@ export destructure, total, total2
 include("rules.jl")
 export Descent, ADAM, Momentum, Nesterov, RMSProp,
        ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM, OADAM, AdaBelief,
-       WeightDecay, ClipGrad, ClipNorm, OptimiserChain
+       WeightDecay, ClipGrad, ClipNorm, OptimiserChain, Lookahead
 
 """
     Optimisers.apply!(rule::RuleType, state, parameters, gradient) -> (state, gradient)
diff --git a/src/rules.jl b/src/rules.jl
index 65392a0f..c929251d 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -555,3 +555,71 @@ function Base.show(io::IO, c::OptimiserChain)
   join(io, c.opts, ", ")
   print(io, ")")
 end
+
+"""
+    Lookahead(opt = Momentum(), k = 10, α = 0.5)
+
+The [Lookahead](https://arxiv.org/abs/1907.08610) optimiser
+keeps a "slow" copy of the parameters, `y`.
+Most steps update only the "fast" parameters `x` as usual,
+using the given `opt`, but every `k`th step updates the slow
+parameters, and then resets the fast ones to match:
+
+    @. y = α * x + (1-α) * y
+    @. x = y
+
+# Parameters
+- Optimiser (`opt`): Used at each fast step.
+- Synchronization period (`k`): Number of fast steps between slow steps.
+- Slow weight step size (`α`): Proportion of the change to fast weights which is kept.
+
+# Example
+```jldoctest
+julia> x = [10.0]; o = Lookahead(Descent(0.5), 4); s = Optimisers.setup(o, x);
+
+julia> for t in 1:8
+         s, x = Optimisers.update!(s, x, [1])
+         y = s.state[3]
+         @show x, y
+       end
+(x, y) = ([9.5], [10.0])
+(x, y) = ([9.0], [10.0])
+(x, y) = ([8.5], [10.0])
+(x, y) = ([9.0], [9.0])
+(x, y) = ([8.5], [9.0])
+(x, y) = ([8.0], [9.0])
+(x, y) = ([7.5], [9.0])
+(x, y) = ([8.0], [8.0])
+```
+"""
+struct Lookahead{O, T<:Real}
+  inner::O
+  k::Int
+  alpha::T
+end
+Lookahead(opt = Momentum(), k::Int = 10, α = 5f-1) = Lookahead{typeof(opt), typeof(α)}(opt, k, α)
+
+init(o::Lookahead, x::AbstractArray) = (init(o.inner, x), 1, copy(x))
+
+update!(ℓ::Leaf{<:Lookahead}, x, ::Zero, ::Zero...) = ℓ, x
+function update!(ℓ::Leaf{<:Lookahead}, x, x̄s...)
+  instate, n, y = ℓ.state
+  is′, x̄′ = apply!(ℓ.rule.inner, instate, x, base.(x̄s)...)
+
+  if n % ℓ.rule.k != 0
+    return Leaf(ℓ.rule, (is′, n+1, y)), subtract!(x, x̄′)
+  else
+    α = ℓ.rule.alpha
+    @.. y = α * (x - x̄′) + (1 - α) * y
+    @.. x = y
+    return Leaf(ℓ.rule, (is′, n+1, y)), x
+  end
+end
+
+function Base.show(io::IO, o::Lookahead)
+  print(io, "Lookahead(")
+  show(io, o.inner)
+  print(io, ", ", o.k, ", ")
+  show(io, o.alpha)
+  print(io, ")")
+end
diff --git a/test/rules.jl b/test/rules.jl
index ffb4ca65..cf37f596 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -14,10 +14,13 @@ RULES = [
   OptimiserChain(ClipNorm(), ADAM(0.001)),
   OptimiserChain(ClipGrad(0.5), Momentum()),
   OptimiserChain(WeightDecay(), OADAM(), ClipGrad(1)),
+  # Lookahead
+  Lookahead(), Lookahead(ADAMW(0.003), 5, 0.7)
 ]
 
 name(o) = typeof(o).name.name # just for printing testset headings
 name(o::OptimiserChain) = join(name.(o.opts), " → ")
+name(o::Lookahead) = string("Lookahead(", name(o.inner), ")")
 
 LOG = Dict() # for debugging these testsets, this makes it easy to plot each optimiser's loss
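
For reference, a minimal usage sketch of the new rule driven through the ordinary `setup`/`update!` loop. This sketch is not part of the patch: the `train` helper, the toy quadratic loss, and the hyperparameter values are illustrative assumptions, chosen only to show the rule composing with an inner optimiser.

```julia
using Optimisers

# Toy loss sum(abs2, x), whose gradient is 2x; both are illustrative choices.
function train(x, rule; steps = 20)
  state = Optimisers.setup(rule, x)
  for _ in 1:steps
    grad = 2 .* x                           # analytic gradient, no AD needed here
    state, x = Optimisers.update!(state, x, grad)
  end
  return x
end

# Lookahead wrapping plain gradient descent, synchronising the slow copy every 5 fast steps.
train([3.0, -2.0], Lookahead(Descent(0.1), 5, 0.5))
```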