diff --git a/src/rules.jl b/src/rules.jl
index 82237eba..1f57f246 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -544,6 +544,48 @@ function apply!(o::AdaBelief, state, x, dx)
   return (mt, st, βt .* β), dx′
 end
 
+"""
+    PAdam(η = 1f-2, β = (9f-1, 9.99f-1), ρ = 2.5f-1, ϵ = eps(typeof(η)))
+
+The partially adaptive momentum estimation method
+([PADAM](https://arxiv.org/pdf/1806.06763v1.pdf)).
+
+# Parameters
+- Learning rate (`η`): Amount by which gradients are discounted before updating
+  the weights.
+- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+  second (β2) momentum estimate.
+- Partially adaptive parameter (`ρ`): Exponent applied to the running maximum of the
+  second-moment estimate. Varies between 0 and 0.5; `ρ = 0.5` gives `sqrt(v̂) + ϵ`.
+- Machine epsilon (`ϵ`): Constant to prevent division by zero
+  (no need to change default)
+"""
+struct PAdam{T} <: AbstractRule
+  eta::T
+  beta::Tuple{T, T}
+  rho::T
+  epsilon::T
+end
+PAdam(η = 1f-2, β = (9f-1, 9.99f-1), ρ = 2.5f-1, ϵ = eps(typeof(η))) = PAdam{typeof(η)}(η, β, ρ, ϵ)
+
+# One state array per estimate (mt, vt, v̂t), each built with `onevalue(o.epsilon, x)`.
+init(o::PAdam, x::AbstractArray) = (onevalue(o.epsilon, x), onevalue(o.epsilon, x), onevalue(o.epsilon, x))
+
+function apply!(o::PAdam, state, x, dx)
+  η, β, ρ, ϵ = o.eta, o.beta, o.rho, o.epsilon
+  mt, vt, v̂t = state
+
+  # Exponential moving averages of the gradient and of the squared gradient.
+  @.. mt = β[1] * mt + (1 - β[1]) * dx
+  @.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
+  # Elementwise running maximum of the second-moment estimate.
+  @.. v̂t = max(v̂t, vt)
+  # Partially adaptive step: the denominator uses v̂t raised to ρ rather than a fixed sqrt.
+  dx′ = @lazy η * mt / (v̂t ^ ρ + ϵ)
+
+  return (mt, vt, v̂t), dx′
+end
+
 """
     WeightDecay(γ = 5f-4)