From bb94d687edce48a3da21042f295bd393103dcab8 Mon Sep 17 00:00:00 2001 From: Ben Murrell Date: Tue, 10 Dec 2024 23:12:41 +0100 Subject: [PATCH 01/13] Adding Apollo draft --- src/rules.jl | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/src/rules.jl b/src/rules.jl index 0063d70e..e1225c17 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -599,6 +599,54 @@ function apply!(o::AdaBelief, state, x::AbstractArray{T}, dx) where T return (mt, st, βt .* β), dx′ end + +""" + Apollo(opt, r, u) + +Apollo optimizer from Zhu et al. (https://arxiv.org/pdf/2412.05270). Tracks moments in a low-rank subspace, aiming for Adam-like behavior with minimal additional memory usage. +`opt` is an AdamW optimizer, `r` is the random projection rank (smaller for lower memory use), and `u` is the random projection update interval. +""" +struct Apollo{T1} <: AbstractRule + opt::T1 + r::Int #Subspace rank + u::Int #Subspace update frequency (T in paper) +end + +#Use the base init and apply for 1D arrays +init(o::Apollo, x::AbstractArray{T,1}) where T = init(o.opt, x) +apply!(o::Apollo, state, x::AbstractArray{T,1}, dx) where T = apply!(o.opt, state, x, dx) + +function init(o::Apollo, x::AbstractArray{T,2}) where T + rank = min(o.r, ceil(Int, size(x,2) / 2)) + P = randn(T, rank, size(x,1)) .* T(1/rank) + ((similar(x, rank, size(x,2)) .= 0, similar(x, rank, size(x,2)) .= 0, o.opt.beta), 0, P) +end + +function apply!(o::Apollo, state, x::AbstractArray{T,2}, dx) where T + (mt, vt, βt), t, P = state + η = T(o.opt.eta) + λ = T(o.opt.lambda) + β = T.(o.opt.beta) + ϵ = T(o.opt.epsilon) + if mod(t, o.u) == 100 + rank = min(o.r, ceil(Int, size(x,2) / 2)) + @show rank, typeof(rank) + P = randn(T, rank, size(x,1)) .* T(1/rank) + end + R = P * dx + Optimisers.@.. mt = β[1] * mt + (1 - β[1]) * R + Optimisers.@.. vt = β[2] * vt + (1 - β[2]) * abs2(R) + Rhat = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) + s = [Optimisers.norm(view(Rhat, :, j)) / Optimisers.norm(view(R, :, j)) for j in 1:size(R,2)] + S = Diagonal(s) + dx′′ = η * dx * S + λ * x + return ((mt, vt, βt .* β), t+1, P), dx′′ +end + + + + + """ WeightDecay(λ = 5e-4) WeightDecay(; [lambda]) From acbe8e326c765ee3360fb55850d8d732845539b1 Mon Sep 17 00:00:00 2001 From: Ben Murrell Date: Wed, 11 Dec 2024 00:09:57 +0100 Subject: [PATCH 02/13] Adding an epsilon to prevent NaNs --- src/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index e1225c17..fbfc48b1 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -637,7 +637,7 @@ function apply!(o::Apollo, state, x::AbstractArray{T,2}, dx) where T Optimisers.@.. mt = β[1] * mt + (1 - β[1]) * R Optimisers.@.. vt = β[2] * vt + (1 - β[2]) * abs2(R) Rhat = @. 
mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) - s = [Optimisers.norm(view(Rhat, :, j)) / Optimisers.norm(view(R, :, j)) for j in 1:size(R,2)] + s = sqrt.(sum(abs2.(Rhat), dims=1))[:] ./ (sqrt.(sum(abs2.(R), dims=1))[:] .+ ϵ) S = Diagonal(s) dx′′ = η * dx * S + λ * x return ((mt, vt, βt .* β), t+1, P), dx′′ From 8a05289aeb6945d73a9ebccec5ac86d0efbfff0e Mon Sep 17 00:00:00 2001 From: Ben Murrell Date: Wed, 11 Dec 2024 20:27:21 +0100 Subject: [PATCH 03/13] Adding GradNormGrowthLimiter, and handling dims>2 and dim sorting in Apollo --- src/rules.jl | 117 ++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 97 insertions(+), 20 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 24c18af1..0aabf134 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -601,52 +601,129 @@ end """ - Apollo(opt, r, u) + GradNormGrowthLimiter(γ = 1.1; m = 1e-3, ϵ = 1e-8, throw = true, paramscale_min = true) + +Gradient norm growth limiter from Chen et al. (https://arxiv.org/pdf/2410.01623) and used with Apollo in Zhu et al. (https://arxiv.org/pdf/2412.05270). +With Optimisers.jl this will apply per-tensor, which may not be the same as the implementations in these papers. It still seems to help, but the ideal settings may vary. +This also introduces `m` a hard minimum on the gradient norm, and never rescales grads below this, preventing a tensor from getting "trapped" near zero. +This can be a fixed min, or scaled by the number of parameters in the tensor (with `paramscale_min = true`). +""" +struct GradNormGrowthLimiter <: AbstractRule + γ::Float64 + m::Float64 #Min grad norm, to stop a tensor getting stuck near zero + ϵ::Float64 + throw::Bool + paramscale_min::Bool +end + +GradNormGrowthLimiter(γ = 1.1; m = 1e-3, ϵ = 1e-8, throw = true, paramscale_min = true) = GradNormGrowthLimiter(γ, m, ϵ, throw, paramscale_min) + +init(o::GradNormGrowthLimiter, x::AbstractArray{T}) where T = T(0) + +function apply!(o::GradNormGrowthLimiter, state, x::AbstractArray{T}, dx) where T + current_norm = Optimisers._norm(dx, 2) + if o.throw && !isfinite(current_norm) + throw(DomainError("gradient has L2-norm $current_norm, for array $(summary(x))")) + end + if state == 0 + return (current_norm), dx + else + #If you're below the hard min, then don't scale + if o.paramscale_min + minthresh = o.m * length(dx) + else + minthresh = o.m + end + if current_norm < minthresh + return current_norm, dx + end + ratio = current_norm / (state + o.ϵ) + if ratio > o.γ + λ = T((o.γ * state) / (current_norm + o.ϵ)) + print(":", current_norm, ":") + return current_norm * λ, dx * λ + else + return current_norm, dx + end + end +end + +nonfirstdims(x) = prod(size(x)[2:end]) + +""" + Apollo(η::Real, rank::Int; u = 100, sort_dims = false) + Apollo(η::Real; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = false) + Apollo(opt::Optimisers.AdamW, rank::Int; u = 100, sort_dims = false) + Apollo(opt::Optimisers.AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = false) Apollo optimizer from Zhu et al. (https://arxiv.org/pdf/2412.05270). Tracks moments in a low-rank subspace, aiming for Adam-like behavior with minimal additional memory usage. -`opt` is an AdamW optimizer, `r` is the random projection rank (smaller for lower memory use), and `u` is the random projection update interval. +First argument can be an AdamW optimizer, or a learning rate (which will use the default AdamW optimizer with that learning rate). 
Second argument can be a rank, or a function +to compute the rank from the second dimension (or the product of all dims > 1) of the weight matrix (or tensor). """ struct Apollo{T1} <: AbstractRule opt::T1 - r::Int #Subspace rank + r::Function #Maps non-first dims to rank u::Int #Subspace update frequency (T in paper) + sort_dims::Bool #Whether to swap the dims of x and dx when the second dim is smaller than the first end + +Apollo() = Apollo(AdamW(0.001), dim -> ceil(Int, sqrt(dim)), 100, true) +Apollo(η::Real, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), dim -> max(dim, rank), u, sort_dims) +Apollo(η::Real; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(AdamW(η), rank_function, u, sort_dims) +Apollo(opt::AdamW, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), dim -> max(dim, rank), u, sort_dims) +Apollo(opt::AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(opt, rank_function, u, sort_dims) + #Use the base init and apply for 1D arrays init(o::Apollo, x::AbstractArray{T,1}) where T = init(o.opt, x) apply!(o::Apollo, state, x::AbstractArray{T,1}, dx) where T = apply!(o.opt, state, x, dx) -function init(o::Apollo, x::AbstractArray{T,2}) where T - rank = min(o.r, ceil(Int, size(x,2) / 2)) - P = randn(T, rank, size(x,1)) .* T(1/rank) - ((similar(x, rank, size(x,2)) .= 0, similar(x, rank, size(x,2)) .= 0, o.opt.beta), 0, P) -end - -function apply!(o::Apollo, state, x::AbstractArray{T,2}, dx) where T +function init(o::Apollo, x::AbstractArray{T}) where T + first_dim, second_dim = size(x,1), nonfirstdims(x) + if o.sort_dims && second_dim < first_dim + first_dim, second_dim = second_dim, first_dim + end + rank = o.r(second_dim) + P = randn(T, rank, first_dim) .* T(1/rank) + ((similar(x, rank, second_dim) .= 0, similar(x, rank, second_dim) .= 0, o.opt.beta), 1, P) +end + +function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T + swapped = false + original_size = size(x) + x = reshape(x, size(x,1), nonfirstdims(x)) + dx = reshape(dx, size(dx,1), nonfirstdims(dx)) + + first_dim, second_dim = size(x,1), size(x,2) + if o.sort_dims && second_dim < first_dim + first_dim, second_dim = second_dim, first_dim + x = x' + dx = dx' + swapped = true + end (mt, vt, βt), t, P = state η = T(o.opt.eta) λ = T(o.opt.lambda) β = T.(o.opt.beta) ϵ = T(o.opt.epsilon) - if mod(t, o.u) == 100 - rank = min(o.r, ceil(Int, size(x,2) / 2)) - @show rank, typeof(rank) - P = randn(T, rank, size(x,1)) .* T(1/rank) + if mod(t, o.u) == 0 + rank = o.r(second_dim) + P = randn(T, rank, first_dim) .* T(1/rank) end R = P * dx - Optimisers.@.. mt = β[1] * mt + (1 - β[1]) * R - Optimisers.@.. vt = β[2] * vt + (1 - β[2]) * abs2(R) + @.. mt = β[1] * mt + (1 - β[1]) * R + @.. vt = β[2] * vt + (1 - β[2]) * abs2(R) Rhat = @. 
mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) s = sqrt.(sum(abs2.(Rhat), dims=1))[:] ./ (sqrt.(sum(abs2.(R), dims=1))[:] .+ ϵ) S = Diagonal(s) dx′′ = η * dx * S + λ * x - return ((mt, vt, βt .* β), t+1, P), dx′′ + if swapped + dx′′ = dx′′' + end + return ((mt, vt, βt .* β), t+1, P), reshape(dx′′, original_size) end - - - """ WeightDecay(λ = 5e-4) WeightDecay(; [lambda]) From b46c0efd3371f36b1bf5c692e3363ff9459c37f7 Mon Sep 17 00:00:00 2001 From: Ben Murrell Date: Wed, 11 Dec 2024 20:29:49 +0100 Subject: [PATCH 04/13] Adding Apollo and GradNormGrowthLimiter to tests --- test/rules.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/rules.jl b/test/rules.jl index 499902ca..b1b19fbf 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -8,12 +8,13 @@ RULES = [ # All the rules at default settings: Descent(), Adam(), Momentum(), Nesterov(), Rprop(), RMSProp(), AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(), - AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(), + AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(), Apollo(), # A few chained combinations: OptimiserChain(SignDecay(0.001), Adam(0.001)), OptimiserChain(ClipNorm(), Adam(0.001)), OptimiserChain(ClipGrad(0.5), Momentum()), OptimiserChain(WeightDecay(), OAdam(), ClipGrad(1)), + OptimiserChain(GradNormGrowthLimiter(1.1), Apollo()), # Not the default: RMSProp(centred = true), AdamW(couple=false), ] From e39add784e7f88a64ea1a65b136c848251b30b7d Mon Sep 17 00:00:00 2001 From: Ben Murrell Date: Wed, 11 Dec 2024 20:43:41 +0100 Subject: [PATCH 05/13] Touch ups --- src/Optimisers.jl | 2 +- src/rules.jl | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/Optimisers.jl b/src/Optimisers.jl index 99fc162f..87a60c60 100644 --- a/src/Optimisers.jl +++ b/src/Optimisers.jl @@ -23,7 +23,7 @@ include("rules.jl") export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp, AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief, WeightDecay, SignDecay, ClipGrad, ClipNorm, OptimiserChain, Lion, - AccumGrad + AccumGrad, Apollo, GradNormGrowthLimiter VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public apply!, init, setup, update, update!")) diff --git a/src/rules.jl b/src/rules.jl index 0aabf134..a08ec1d1 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -621,7 +621,7 @@ GradNormGrowthLimiter(γ = 1.1; m = 1e-3, ϵ = 1e-8, throw = true, paramscale_mi init(o::GradNormGrowthLimiter, x::AbstractArray{T}) where T = T(0) function apply!(o::GradNormGrowthLimiter, state, x::AbstractArray{T}, dx) where T - current_norm = Optimisers._norm(dx, 2) + current_norm = _norm(dx, 2) if o.throw && !isfinite(current_norm) throw(DomainError("gradient has L2-norm $current_norm, for array $(summary(x))")) end @@ -640,7 +640,6 @@ function apply!(o::GradNormGrowthLimiter, state, x::AbstractArray{T}, dx) where ratio = current_norm / (state + o.ϵ) if ratio > o.γ λ = T((o.γ * state) / (current_norm + o.ϵ)) - print(":", current_norm, ":") return current_norm * λ, dx * λ else return current_norm, dx @@ -653,8 +652,8 @@ nonfirstdims(x) = prod(size(x)[2:end]) """ Apollo(η::Real, rank::Int; u = 100, sort_dims = false) Apollo(η::Real; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = false) - Apollo(opt::Optimisers.AdamW, rank::Int; u = 100, sort_dims = false) - Apollo(opt::Optimisers.AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = false) + Apollo(opt::AdamW, rank::Int; u = 100, sort_dims = false) + Apollo(opt::AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), 
u = 100, sort_dims = false) Apollo optimizer from Zhu et al. (https://arxiv.org/pdf/2412.05270). Tracks moments in a low-rank subspace, aiming for Adam-like behavior with minimal additional memory usage. First argument can be an AdamW optimizer, or a learning rate (which will use the default AdamW optimizer with that learning rate). Second argument can be a rank, or a function From 43d30c60fd367c595570c85bc5326b802143f7e2 Mon Sep 17 00:00:00 2001 From: Ben Murrell Date: Thu, 12 Dec 2024 12:52:54 +0100 Subject: [PATCH 06/13] Attempting gradient types fix --- src/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index a08ec1d1..7bdea5e5 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -691,7 +691,7 @@ function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T swapped = false original_size = size(x) x = reshape(x, size(x,1), nonfirstdims(x)) - dx = reshape(dx, size(dx,1), nonfirstdims(dx)) + dx = reshape(dx, size(x,1), nonfirstdims(dx)) first_dim, second_dim = size(x,1), size(x,2) if o.sort_dims && second_dim < first_dim From b282c3524503d86781d504253e60efb068f905a5 Mon Sep 17 00:00:00 2001 From: Ben Murrell Date: Thu, 12 Dec 2024 13:24:58 +0100 Subject: [PATCH 07/13] Forcing materialize. --- src/rules.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index 7bdea5e5..568f745e 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -691,7 +691,9 @@ function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T swapped = false original_size = size(x) x = reshape(x, size(x,1), nonfirstdims(x)) - dx = reshape(dx, size(x,1), nonfirstdims(dx)) + + dx = Broadcast.materialize(dx) #This is to stop the "gradient type" @lazy test from failing due to reshape. + dx = reshape(dx, size(x,1), nonfirstdims(x)) first_dim, second_dim = size(x,1), size(x,2) if o.sort_dims && second_dim < first_dim From ca2ae0ab805326ad49b98b173f63861b6adf28d5 Mon Sep 17 00:00:00 2001 From: murrellb Date: Fri, 13 Dec 2024 12:00:38 +0100 Subject: [PATCH 08/13] GPU attempt --- src/Optimisers.jl | 2 ++ src/rules.jl | 8 ++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/Optimisers.jl b/src/Optimisers.jl index 87a60c60..e4b992ba 100644 --- a/src/Optimisers.jl +++ b/src/Optimisers.jl @@ -5,6 +5,8 @@ using Functors: functor, fmap, fmap_with_path, isleaf, @functor, fmapstructure, children, AbstractWalk using LinearAlgebra +using Random: randn! + include("interface.jl") export AbstractRule diff --git a/src/rules.jl b/src/rules.jl index 568f745e..a9080625 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -683,10 +683,13 @@ function init(o::Apollo, x::AbstractArray{T}) where T first_dim, second_dim = second_dim, first_dim end rank = o.r(second_dim) - P = randn(T, rank, first_dim) .* T(1/rank) + P = similar(x, rank, first_dim) + randn!(P) + P .*= T(sqrt(1/rank)) ((similar(x, rank, second_dim) .= 0, similar(x, rank, second_dim) .= 0, o.opt.beta), 1, P) end + function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T swapped = false original_size = size(x) @@ -709,7 +712,8 @@ function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T ϵ = T(o.opt.epsilon) if mod(t, o.u) == 0 rank = o.r(second_dim) - P = randn(T, rank, first_dim) .* T(1/rank) + randn!(P) + P .*= T(sqrt(1/rank)) end R = P * dx @.. 
mt = β[1] * mt + (1 - β[1]) * R From 6aa32c13bff5ac0ee365f5ade9ba1dbad513dff5 Mon Sep 17 00:00:00 2001 From: murrellb Date: Fri, 13 Dec 2024 14:10:11 +0100 Subject: [PATCH 09/13] Fixing types, and avoiding GPU scalarindexing from adjoint * Diagonal --- src/rules.jl | 70 ++++++++++++++++++++++++++-------------------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index a9080625..4ba550dc 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -691,41 +691,41 @@ end function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T - swapped = false - original_size = size(x) - x = reshape(x, size(x,1), nonfirstdims(x)) - - dx = Broadcast.materialize(dx) #This is to stop the "gradient type" @lazy test from failing due to reshape. - dx = reshape(dx, size(x,1), nonfirstdims(x)) - - first_dim, second_dim = size(x,1), size(x,2) - if o.sort_dims && second_dim < first_dim - first_dim, second_dim = second_dim, first_dim - x = x' - dx = dx' - swapped = true - end - (mt, vt, βt), t, P = state - η = T(o.opt.eta) - λ = T(o.opt.lambda) - β = T.(o.opt.beta) - ϵ = T(o.opt.epsilon) - if mod(t, o.u) == 0 - rank = o.r(second_dim) - randn!(P) - P .*= T(sqrt(1/rank)) - end - R = P * dx - @.. mt = β[1] * mt + (1 - β[1]) * R - @.. vt = β[2] * vt + (1 - β[2]) * abs2(R) - Rhat = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) - s = sqrt.(sum(abs2.(Rhat), dims=1))[:] ./ (sqrt.(sum(abs2.(R), dims=1))[:] .+ ϵ) - S = Diagonal(s) - dx′′ = η * dx * S + λ * x - if swapped - dx′′ = dx′′' - end - return ((mt, vt, βt .* β), t+1, P), reshape(dx′′, original_size) + swapped = false + original_size = size(x) + x = reshape(x, size(x,1), nonfirstdims(x)) + + dx = Broadcast.materialize(dx) #This is to stop the "gradient type" @lazy test from failing due to reshape. + dx = reshape(dx, size(x,1), nonfirstdims(x)) + + first_dim, second_dim = size(x,1), size(x,2) + if o.sort_dims && second_dim < first_dim + first_dim, second_dim = second_dim, first_dim + x = x' + dx = dx' + swapped = true + end + (mt, vt, βt), t, P = state + η = T(o.opt.eta) + λ = T(o.opt.lambda) + β = T.(o.opt.beta) + ϵ = T(o.opt.epsilon) + βt = T.(βt) + if mod(t, o.u) == 0 + rank = o.r(second_dim) + randn!(P) + P .*= T(sqrt(1/rank)) + end + R = P * dx + @.. mt = β[1] * mt + (1 - β[1]) * R + @.. vt = β[2] * vt + (1 - β[2]) * abs2(R) + Rhat = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) + s = sqrt.(sum(abs2.(Rhat), dims=1))[:] ./ (sqrt.(sum(abs2.(R), dims=1))[:] .+ ϵ) + dx′′ = η * (dx .* reshape(s, 1, :)) + λ * x + if swapped + dx′′ = dx′′' + end + return ((mt, vt, βt .* β), t+1, P), reshape(dx′′, original_size) end From d9637c6246a0b7b8f00c8831c56bef5be0f95ade Mon Sep 17 00:00:00 2001 From: murrellb Date: Sat, 14 Dec 2024 20:29:58 +0100 Subject: [PATCH 10/13] Give Apollo its own eta for adjust, and use sqrt(#params) for GradNormGrowthLimiter --- src/rules.jl | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 4ba550dc..08216df6 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -603,10 +603,10 @@ end """ GradNormGrowthLimiter(γ = 1.1; m = 1e-3, ϵ = 1e-8, throw = true, paramscale_min = true) -Gradient norm growth limiter from Chen et al. (https://arxiv.org/pdf/2410.01623) and used with Apollo in Zhu et al. (https://arxiv.org/pdf/2412.05270). -With Optimisers.jl this will apply per-tensor, which may not be the same as the implementations in these papers. It still seems to help, but the ideal settings may vary. 
-This also introduces `m` a hard minimum on the gradient norm, and never rescales grads below this, preventing a tensor from getting "trapped" near zero. -This can be a fixed min, or scaled by the number of parameters in the tensor (with `paramscale_min = true`). +Gradient norm growth limiter. Inspired by [Chen et al.](https://arxiv.org/abs/2410.01623) and used with Apollo in [Zhu et al.](https://arxiv.org/abs/2412.05270), but +with Optimisers.jl this will apply per-tensor instead of per-model, and as a result the defaults are different. `γ` controls the maximum that the gradient norm can grow +from one step to the next. This implementation also introduces `m` a hard minimum on the gradient norm threshold, and never rescales grads below this, preventing a tensor +from getting "trapped" near zero. This can be a fixed min, or scaled by the square root of the number of parameters in the tensor (with `paramscale_min = true`). """ struct GradNormGrowthLimiter <: AbstractRule γ::Float64 @@ -630,7 +630,7 @@ function apply!(o::GradNormGrowthLimiter, state, x::AbstractArray{T}, dx) where else #If you're below the hard min, then don't scale if o.paramscale_min - minthresh = o.m * length(dx) + minthresh = o.m * sqrt(length(dx)) else minthresh = o.m end @@ -659,19 +659,20 @@ Apollo optimizer from Zhu et al. (https://arxiv.org/pdf/2412.05270). Tracks mome First argument can be an AdamW optimizer, or a learning rate (which will use the default AdamW optimizer with that learning rate). Second argument can be a rank, or a function to compute the rank from the second dimension (or the product of all dims > 1) of the weight matrix (or tensor). """ -struct Apollo{T1} <: AbstractRule +struct Apollo{T1, T2, T3, T4, T5} <: AbstractRule opt::T1 - r::Function #Maps non-first dims to rank - u::Int #Subspace update frequency (T in paper) - sort_dims::Bool #Whether to swap the dims of x and dx when the second dim is smaller than the first + eta::T2 + r::T3 #Maps non-first dims to rank + u::T4 #Subspace update frequency (T in paper) + sort_dims::T5 #Whether to swap the dims of x and dx when the second dim is smaller than the first end -Apollo() = Apollo(AdamW(0.001), dim -> ceil(Int, sqrt(dim)), 100, true) -Apollo(η::Real, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), dim -> max(dim, rank), u, sort_dims) -Apollo(η::Real; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(AdamW(η), rank_function, u, sort_dims) -Apollo(opt::AdamW, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), dim -> max(dim, rank), u, sort_dims) -Apollo(opt::AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(opt, rank_function, u, sort_dims) +Apollo() = Apollo(AdamW(0.001), 0.001, dim -> ceil(Int, sqrt(dim)), 100, true) +Apollo(η::Real, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), η, dim -> max(dim, rank), u, sort_dims) +Apollo(η::Real; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(AdamW(η), η, rank_function, u, sort_dims) +Apollo(opt::AdamW, rank::Int; u = 100, sort_dims = true) = Apollo(opt, opt.eta, dim -> max(dim, rank), u, sort_dims) +Apollo(opt::AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(opt, opt.eta, rank_function, u, sort_dims) #Use the base init and apply for 1D arrays init(o::Apollo, x::AbstractArray{T,1}) where T = init(o.opt, x) @@ -706,7 +707,7 @@ function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T swapped = 
true end (mt, vt, βt), t, P = state - η = T(o.opt.eta) + η = T(o.eta) #This is what will get modified by adjust λ = T(o.opt.lambda) β = T.(o.opt.beta) ϵ = T(o.opt.epsilon) @@ -728,6 +729,9 @@ function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T return ((mt, vt, βt .* β), t+1, P), reshape(dx′′, original_size) end +#Notes: chuck the AdamW from the struct, so that adjust will just work. + + """ WeightDecay(λ = 5e-4) From c75142f9f07d0deee536dd27dd9c36ebf6c82d01 Mon Sep 17 00:00:00 2001 From: murrellb Date: Mon, 16 Dec 2024 03:10:44 +0100 Subject: [PATCH 11/13] Various tweaks. --- src/Optimisers.jl | 2 +- src/rules.jl | 95 ++++++++++++++++++++++++++--------------------- 2 files changed, 53 insertions(+), 44 deletions(-) diff --git a/src/Optimisers.jl b/src/Optimisers.jl index e4b992ba..5a7e5c2e 100644 --- a/src/Optimisers.jl +++ b/src/Optimisers.jl @@ -25,7 +25,7 @@ include("rules.jl") export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp, AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief, WeightDecay, SignDecay, ClipGrad, ClipNorm, OptimiserChain, Lion, - AccumGrad, Apollo, GradNormGrowthLimiter + AccumGrad, Apollo, NormGrowthCap VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public apply!, init, setup, update, update!")) diff --git a/src/rules.jl b/src/rules.jl index 08216df6..53af624c 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -599,28 +599,29 @@ function apply!(o::AdaBelief, state, x::AbstractArray{T}, dx) where T return (mt, st, βt .* β), dx′ end - """ - GradNormGrowthLimiter(γ = 1.1; m = 1e-3, ϵ = 1e-8, throw = true, paramscale_min = true) + NormGrowthCap(τ = 1.01; ϵ = 1e-8, lb = 1e-7, throw = true, scale = true) -Gradient norm growth limiter. Inspired by [Chen et al.](https://arxiv.org/abs/2410.01623) and used with Apollo in [Zhu et al.](https://arxiv.org/abs/2412.05270), but -with Optimisers.jl this will apply per-tensor instead of per-model, and as a result the defaults are different. `γ` controls the maximum that the gradient norm can grow -from one step to the next. This implementation also introduces `m` a hard minimum on the gradient norm threshold, and never rescales grads below this, preventing a tensor -from getting "trapped" near zero. This can be a fixed min, or scaled by the square root of the number of parameters in the tensor (with `paramscale_min = true`). +Gradient norm growth limiter. `τ` controls the maximum that the gradient norm can grow from one step to the next, such that +if `||dx||/||dx_prev|| > τ` & `||dx|| > lb`, then `dx = dx * τ*||dx_prev||/(||dx||+ϵ)` +Inspired by [Chen et al.](https://arxiv.org/abs/2410.01623) and used with Apollo in [Zhu et al.](https://arxiv.org/abs/2412.05270), but +with Optimisers.jl this will apply per-tensor instead of per-model. This implementation also introduces `lb` as a hard minimum on the gradient norm threshold, +and never rescales grads below this, preventing a tensor from getting "trapped" near zero. This can be a fixed min, or scaled by the square root of the +number of parameters in the tensor (with `scale = true`). 
""" -struct GradNormGrowthLimiter <: AbstractRule - γ::Float64 - m::Float64 #Min grad norm, to stop a tensor getting stuck near zero - ϵ::Float64 +struct NormGrowthCap <: AbstractRule + tau::Float64 + epsilon::Float64 + lb::Float64 #Min grad norm, to stop a tensor getting stuck near zero throw::Bool - paramscale_min::Bool + scale::Bool end -GradNormGrowthLimiter(γ = 1.1; m = 1e-3, ϵ = 1e-8, throw = true, paramscale_min = true) = GradNormGrowthLimiter(γ, m, ϵ, throw, paramscale_min) +NormGrowthCap(τ = 1.01; ϵ = 1e-8, lb = 1e-7, throw = true, scale = true) = NormGrowthCap(τ, ϵ, lb, throw, scale) -init(o::GradNormGrowthLimiter, x::AbstractArray{T}) where T = T(0) +init(o::NormGrowthCap, x::AbstractArray{T}) where T = T(0) -function apply!(o::GradNormGrowthLimiter, state, x::AbstractArray{T}, dx) where T +function apply!(o::NormGrowthCap, state, x::AbstractArray{T}, dx) where T current_norm = _norm(dx, 2) if o.throw && !isfinite(current_norm) throw(DomainError("gradient has L2-norm $current_norm, for array $(summary(x))")) @@ -629,18 +630,18 @@ function apply!(o::GradNormGrowthLimiter, state, x::AbstractArray{T}, dx) where return (current_norm), dx else #If you're below the hard min, then don't scale - if o.paramscale_min - minthresh = o.m * sqrt(length(dx)) + if o.scale + minthresh = o.lb * sqrt(length(dx)) else - minthresh = o.m + minthresh = o.lb end if current_norm < minthresh return current_norm, dx end - ratio = current_norm / (state + o.ϵ) - if ratio > o.γ - λ = T((o.γ * state) / (current_norm + o.ϵ)) - return current_norm * λ, dx * λ + ratio = current_norm / (state + o.epsilon) + if ratio > o.tau + lambda = T((o.tau * state) / (current_norm + o.epsilon)) + return current_norm * lambda, dx * lambda else return current_norm, dx end @@ -650,29 +651,36 @@ end nonfirstdims(x) = prod(size(x)[2:end]) """ - Apollo(η::Real, rank::Int; u = 100, sort_dims = false) - Apollo(η::Real; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = false) - Apollo(opt::AdamW, rank::Int; u = 100, sort_dims = false) - Apollo(opt::AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = false) + Apollo(opt::AdamW = AdamW(), r::Function = dim -> ceil(Int, sqrt(dim)); u = 100, sort_dims = true) + Apollo(η::Real, args...; kw...) + Apollo(arg, rank::Int; kw...) + Apollo(η::Real, rank::Int; kw...) -Apollo optimizer from Zhu et al. (https://arxiv.org/pdf/2412.05270). Tracks moments in a low-rank subspace, aiming for Adam-like behavior with minimal additional memory usage. +Apollo optimizer from Zhu et al. (https://arxiv.org/abs/2412.05270). Tracks moments in a low-rank subspace, aiming for Adam-like behavior with minimal additional memory usage. First argument can be an AdamW optimizer, or a learning rate (which will use the default AdamW optimizer with that learning rate). Second argument can be a rank, or a function to compute the rank from the second dimension (or the product of all dims > 1) of the weight matrix (or tensor). """ -struct Apollo{T1, T2, T3, T4, T5} <: AbstractRule +struct Apollo{T1, T2} <: AbstractRule opt::T1 - eta::T2 - r::T3 #Maps non-first dims to rank - u::T4 #Subspace update frequency (T in paper) - sort_dims::T5 #Whether to swap the dims of x and dx when the second dim is smaller than the first + r::T2 #Maps non-first dims to rank + u::Int #Subspace update frequency (T in paper) + sort_dims::Bool #Whether to swap the dims of x and dx when the second dim is smaller than the first end +function adjust(r::Apollo; kw...) 
+ if (:u in keys(kw)) || (:r in keys(kw)) || (:sort_dims in keys(kw)) + @error "Apollo does not support adjusting: u, r, sort_dims" + end + return Apollo(adjust(r.opt, NamedTuple(kw)), r.r, r.u, r.sort_dims) +end +adjust(r::Apollo, η::Real) = Apollo(adjust(r.opt, η), r.r, r.u, r.sort_dims) + + +Apollo(opt::AdamW = AdamW(), r::Function = dim -> ceil(Int, sqrt(dim)); u = 100, sort_dims = true) = Apollo(opt, r, u, sort_dims) +Apollo(η::Real, args...; kw...) = Apollo(AdamW(η), args...; kw...) +Apollo(arg, rank::Int; kw...) = Apollo(arg, dim -> min(dim, rank); kw...) +Apollo(η::Real, rank::Int; kw...) = Apollo(AdamW(η), rank; kw...) -Apollo() = Apollo(AdamW(0.001), 0.001, dim -> ceil(Int, sqrt(dim)), 100, true) -Apollo(η::Real, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), η, dim -> max(dim, rank), u, sort_dims) -Apollo(η::Real; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(AdamW(η), η, rank_function, u, sort_dims) -Apollo(opt::AdamW, rank::Int; u = 100, sort_dims = true) = Apollo(opt, opt.eta, dim -> max(dim, rank), u, sort_dims) -Apollo(opt::AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(opt, opt.eta, rank_function, u, sort_dims) #Use the base init and apply for 1D arrays init(o::Apollo, x::AbstractArray{T,1}) where T = init(o.opt, x) @@ -707,7 +715,7 @@ function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T swapped = true end (mt, vt, βt), t, P = state - η = T(o.eta) #This is what will get modified by adjust + η = T(o.opt.eta) #This is what will get modified by adjust λ = T(o.opt.lambda) β = T.(o.opt.beta) ϵ = T(o.opt.epsilon) @@ -721,17 +729,18 @@ function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T @.. mt = β[1] * mt + (1 - β[1]) * R @.. vt = β[2] * vt + (1 - β[2]) * abs2(R) Rhat = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) - s = sqrt.(sum(abs2.(Rhat), dims=1))[:] ./ (sqrt.(sum(abs2.(R), dims=1))[:] .+ ϵ) - dx′′ = η * (dx .* reshape(s, 1, :)) + λ * x + + R2sum = sum(abs2, R; dims=1) + Rhat2sum = sum(abs2, Rhat; dims=1) + s = @. sqrt(Rhat2sum) / (sqrt(R2sum) + ϵ) + dx′′ = η * (dx .* s) + λ * x + if swapped - dx′′ = dx′′' + dx′′ = transpose(dx′′) end return ((mt, vt, βt .* β), t+1, P), reshape(dx′′, original_size) end -#Notes: chuck the AdamW from the struct, so that adjust will just work. - - """ WeightDecay(λ = 5e-4) From b95fd3c61edc1ff8de9b5f6f42b156e6f27534e3 Mon Sep 17 00:00:00 2001 From: murrellb Date: Mon, 16 Dec 2024 03:13:50 +0100 Subject: [PATCH 12/13] And changing the changed name in the tests. --- test/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rules.jl b/test/rules.jl index b1b19fbf..4ce4dec4 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -14,7 +14,7 @@ RULES = [ OptimiserChain(ClipNorm(), Adam(0.001)), OptimiserChain(ClipGrad(0.5), Momentum()), OptimiserChain(WeightDecay(), OAdam(), ClipGrad(1)), - OptimiserChain(GradNormGrowthLimiter(1.1), Apollo()), + OptimiserChain(NormGrowthCap(1.1), Apollo()), # Not the default: RMSProp(centred = true), AdamW(couple=false), ] From 97b0332845c538586c72735658a5e15f9d7b6c3b Mon Sep 17 00:00:00 2001 From: murrellb Date: Thu, 19 Dec 2024 19:46:47 +0100 Subject: [PATCH 13/13] Fixing adjust issue --- src/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index 53af624c..d89bf802 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -671,7 +671,7 @@ function adjust(r::Apollo; kw...) 
if (:u in keys(kw)) || (:r in keys(kw)) || (:sort_dims in keys(kw)) @error "Apollo does not support adjusting: u, r, sort_dims" end - return Apollo(adjust(r.opt, NamedTuple(kw)), r.r, r.u, r.sort_dims) + return Apollo(_adjust(r.opt, NamedTuple(kw)), r.r, r.u, r.sort_dims) end adjust(r::Apollo, η::Real) = Apollo(adjust(r.opt, η), r.r, r.u, r.sort_dims)
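
Editor's usage note (not part of the patches above): a minimal sketch of how the new rules might be combined once this series is applied. The model, gradient, and hyperparameter values below are hypothetical stand-ins for something produced by Flux/Zygote; only `OptimiserChain`, `NormGrowthCap`, `Apollo`, and the standard `Optimisers.setup`/`update!`/`adjust` API from this PR and the existing package are assumed.

    using Optimisers

    # Hypothetical model: a NamedTuple of arrays (any Functors-compatible structure works).
    model = (W1 = randn(Float32, 64, 32), b1 = zeros(Float32, 64),
             W2 = randn(Float32, 10, 64), b2 = zeros(Float32, 10))

    # Cap per-tensor gradient-norm growth, then apply Apollo with learning rate 1e-3 and
    # a projection rank of at most 16; 1D arrays (biases) fall back to plain AdamW updates.
    rule = OptimiserChain(NormGrowthCap(1.01), Apollo(1f-3, 16))

    state = Optimisers.setup(rule, model)

    # Fake gradient with the same structure as the model, standing in for Zygote output.
    grad = map(x -> randn(Float32, size(x)), model)

    state, model = Optimisers.update!(state, model, grad)

    # Patch 13 routes keyword adjustment through the wrapped AdamW, so the usual
    # learning-rate adjustment still works on the whole state tree:
    state = Optimisers.adjust(state, 1f-4)

Keeping the AdamW instance inside the `Apollo` struct (rather than copying its fields) is what lets `adjust` delegate to `_adjust` on the inner rule, so schedules written for AdamW should carry over unchanged; `u`, `r`, and `sort_dims` remain fixed after construction, as the `adjust` method in patch 11 enforces.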