diff --git a/docs/src/api.md b/docs/src/api.md
index 6c021f25..efc886e9 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -17,6 +17,7 @@ Optimisers.AMSGrad
 Optimisers.NAdam
 Optimisers.AdamW
 Optimisers.AdaBelief
+Optimisers.PAdam
 ```
 
 In addition to the main course, you may wish to order some of these condiments:
diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index 20fc8aad..da5df27d 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -13,7 +13,7 @@ export destructure
 include("rules.jl")
 export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
-       AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
+       AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief, PAdam,
        WeightDecay, ClipGrad, ClipNorm, OptimiserChain, Lion, AccumGrad
diff --git a/src/rules.jl b/src/rules.jl
index e994b740..74036a1b 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -543,7 +543,41 @@ function apply!(o::AdaBelief, state, x, dx)
   return (mt, st, βt .* β), dx′
 end
 
+"""
+    PAdam(η = 1f-2, β = (9f-1, 9.99f-1), ρ = 2.5f-1, ϵ = eps(typeof(η)))
+
+The partially adaptive momentum estimation method (PAdam) [https://arxiv.org/pdf/1806.06763v1.pdf], an AMSGrad variant whose second-moment exponent `ρ` can be tuned below `1/2`.
+# Parameters
+- Learning rate (`η`): Amount by which gradients are discounted before updating
+  the weights.
+- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+  second (β2) momentum estimate.
+- Partially adaptive parameter (`ρ`): Exponent of the second-moment term, between 0 and 0.5.
+- Machine epsilon (`ϵ`): Constant to prevent division by zero
+  (no need to change default)
+"""
+struct PAdam{T} <: AbstractRule
+  eta::T
+  beta::Tuple{T, T}
+  rho::T
+  epsilon::T
+end
+PAdam(η = 1f-2, β = (9f-1, 9.99f-1), ρ = 2.5f-1, ϵ = eps(typeof(η))) = PAdam{typeof(η)}(η, β, ρ, ϵ)
+
+init(o::PAdam, x::AbstractArray) = (onevalue(o.epsilon, x), onevalue(o.epsilon, x), onevalue(o.epsilon, x))
+
+function apply!(o::PAdam, state, x, dx)
+  η, β, ρ, ϵ = o.eta, o.beta, o.rho, o.epsilon
+  mt, vt, v̂t = state
+
+  @.. mt = β[1] * mt + (1 - β[1]) * dx
+  @.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
+  @.. v̂t = max(v̂t, vt)
+  dx′ = @lazy η * mt / (v̂t ^ ρ + ϵ)
+
+  return (mt, vt, v̂t), dx′
+end
 
 """
     WeightDecay(γ = 5f-4)
diff --git a/test/rules.jl b/test/rules.jl
index a10e055f..a4d3b8f6 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -8,7 +8,7 @@ RULES = [
   # All the rules at default settings:
   Descent(), Adam(), Momentum(), Nesterov(), Rprop(), RMSProp(),
   AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(),
-  AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(),
+  AdamW(), RAdam(), OAdam(), AdaBelief(), PAdam(), Lion(),
   # A few chained combinations:
   OptimiserChain(WeightDecay(), Adam(0.001)),
   OptimiserChain(ClipNorm(), Adam(0.001)),
@@ -266,4 +266,4 @@ end
   tree, x4 = Optimisers.update(tree, x3, g4)
 
   @test x4 ≈ x3
-end
\ No newline at end of file
+end
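
For reference, a minimal usage sketch of the rule added above, going through the existing `Optimisers.setup`/`Optimisers.update` API; the array and gradient values are illustrative only, and `PAdam()` relies on the defaults defined in the diff:

```julia
using Optimisers

# Toy parameters and a matching gradient (illustrative values only).
x = ones(Float32, 3)
dx = Float32[0.1, -0.2, 0.3]

# PAdam() uses the defaults from this patch: η = 1f-2, β = (9f-1, 9.99f-1), ρ = 2.5f-1.
state = Optimisers.setup(PAdam(), x)

# One optimisation step: returns the updated optimiser state and parameters.
state, x = Optimisers.update(state, x, dx)
```

Since `v̂t ^ 0.5` is just `sqrt(v̂t)`, setting `ρ = 0.5` should make the step essentially match AMSGrad's, which could serve as an extra sanity check in the tests.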