From 950e4a91cabda470a7c348856732621fe2248687 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 13 Oct 2022 16:00:36 -0400
Subject: [PATCH 1/8] add freeze/thaw

---
 docs/src/api.md  |  2 ++
 src/adjust.jl    | 53 ++++++++++++++++++++++++++++++++++++++++++++++++
 src/interface.jl |  6 +++++-
 3 files changed, 60 insertions(+), 1 deletion(-)

diff --git a/docs/src/api.md b/docs/src/api.md
index 1017af94..ad00a2aa 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -35,6 +35,8 @@ Optimisers.setup
 Optimisers.update
 Optimisers.update!
 Optimisers.adjust(::Any, ::Real)
+Optimisers.freeze!
+Optimisers.thaw!
 ```
 
 Calling `Functors.@functor` on your model's layer types by default causes
diff --git a/src/adjust.jl b/src/adjust.jl
index 78b3d452..f3df1025 100644
--- a/src/adjust.jl
+++ b/src/adjust.jl
@@ -1,3 +1,56 @@
+###
+### freeze!
+###
+
+"""
+    Optimisers.freeze!(tree)
+
+Temporarily alters the state `tree = setup(rule, model)` so that parameters will not be updated.
+Can be applied to the state corresponding to only part of a model, for instance `model.layers[1]`.
+Un-done by [`thaw!`](@ref Optimisers.thaw).
+
+# Example
+```jldoctest
+julia> m = (x = ([1.0], 2.0), y = [3.0]);
+
+julia> s = Optimisers.setup(Momentum(), m);
+
+julia> Optimisers.freeze!(s.x)
+
+julia> Optimisers.update!(s, m, (x = ([pi], 10pi), y = [100pi]));  # with fake gradient
+
+julia> m
+(x = ([1.0], 2.0), y = [-0.14159258336972558])
+
+julia> s  # Leaf(..., true) means frozen
+(x = (Leaf(Momentum{Float32}(0.01, 0.9), [0.0], true), ()), y = Leaf(Momentum{Float32}(0.01, 0.9), [3.14159]))
+
+julia> Optimisers.thaw!(s)
+
+julia> s.x[1]
+Leaf(Momentum{Float32}(0.01, 0.9), [0.0])
+```
+"""
+freeze!(tree) = (fmapstructure(freeze!, tree; exclude = x -> x isa Leaf); nothing)
+freeze!(ℓ::Leaf) = (ℓ.frozen = true; nothing)
+
+"""
+    Optimisers.thaw!(tree)
+
+Un-does [`freeze!`](@ref Optimisers.freeze!) for all parameters,
+mutating every `Leaf(rule, state, true)` to `Leaf(rule, state, false)`.
+"""
+thaw!(tree) = (fmapstructure(thaw!, tree; exclude = x -> x isa Leaf); nothing)
+thaw!(ℓ::Leaf) = (ℓ.frozen = false; nothing)
+
+freeze!(::Union{Number, AbstractArray{<:Number}}) = throw(ArgumentError(
+  "`freeze!` must not be applied to a model, only to the state tree from `setup`"))
+thaw!(::Union{Number, AbstractArray{<:Number}}) = throw(ArgumentError(
+  "`thaw!` must not be applied to a model, only to the state tree from `setup`"))
+
+###
+### adjust
+###
 """
     Optimisers.adjust(tree, η) -> tree
 
diff --git a/src/interface.jl b/src/interface.jl
index 79d03396..c94146f8 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -10,10 +10,12 @@ abstract type AbstractRule end
 ### setup
 ###
 
-mutable struct Leaf{R,S}  # mutable so that its identity encodes parameter sharing
+mutable struct Leaf{R,S}  # mutable so that its identity encodes parameter sharing...
   rule::R
   state::S
+  frozen::Bool  # ... and to allow freeze! to act on this.
 end
+Leaf(rule, state) = Leaf(rule, state, false)
 
 @functor Leaf
 
@@ -46,6 +48,7 @@ function Base.show(io::IO, ℓ::Leaf)  # show method is mostly to hide its long
   ioc = IOContext(io, :compact => true)
   print(ioc, "Leaf(", ℓ.rule, ", ")
   show(ioc, ℓ.state)
+  ℓ.frozen && print(ioc, ", true")
   print(ioc, ")")
 end
@@ -83,6 +86,7 @@ function _update!(tree, x; grads, params)
 end
 function _update!(ℓ::Leaf, x; grads, params)
   haskey(params, (ℓ,x)) && return params[(ℓ,x)]
+  ℓ.frozen && return x
   params[(ℓ,x)] = if haskey(grads, ℓ)
     ℓ.state, x̄′ = apply!(ℓ.rule, ℓ.state, x, grads[ℓ]...)
     subtract!(x, x̄′)
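The comment in `src/interface.jl` above is doing real work: `Leaf` is mutable both so that `===` identity can encode parameter sharing, and so that `freeze!` can flip a flag that every reference to the leaf will see. A minimal sketch of why that matters, using a hypothetical `ToyLeaf` rather than the package's own type:

```julia
# Sketch only: ToyLeaf stands in for Optimisers.Leaf, to show why mutability matters here.
mutable struct ToyLeaf
    frozen::Bool
end

shared = ToyLeaf(false)
state = (enc = shared, dec = shared)  # a tied parameter: the same leaf on two branches

state.enc.frozen = true  # freeze via one path...
state.dec.frozen         # ...and the other path sees it: true
```

An immutable `Leaf` would force `freeze!` to rebuild the whole tree instead of editing it in place, which is exactly what the design above avoids.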
From 1730dcdf0d28166ebad4f1a502316306511a0ffb Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 13 Oct 2022 17:07:46 -0400
Subject: [PATCH 2/8] make a keyword frozen, and print it

---
 src/adjust.jl    |  8 ++++----
 src/interface.jl | 10 +++++-----
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/adjust.jl b/src/adjust.jl
index f3df1025..67882bc8 100644
--- a/src/adjust.jl
+++ b/src/adjust.jl
@@ -22,13 +22,13 @@ julia> Optimisers.update!(s, m, (x = ([pi], 10pi), y = [100pi]));  # with fake g
 julia> m
 (x = ([1.0], 2.0), y = [-0.14159258336972558])
 
-julia> s  # Leaf(..., true) means frozen
-(x = (Leaf(Momentum{Float32}(0.01, 0.9), [0.0], true), ()), y = Leaf(Momentum{Float32}(0.01, 0.9), [3.14159]))
+julia> s
+(x = (Leaf(Momentum{Float32}(0.01, 0.9), [0.0], frozen=true), ()), y = Leaf(Momentum{Float32}(0.01, 0.9), [3.14159]))
 
 julia> Optimisers.thaw!(s)
 
-julia> s.x[1]
-Leaf(Momentum{Float32}(0.01, 0.9), [0.0])
+julia> s.x
+(Leaf(Momentum{Float32}(0.01, 0.9), [0.0]), ())
 ```
 """
 freeze!(tree) = (fmapstructure(freeze!, tree; exclude = x -> x isa Leaf); nothing)
diff --git a/src/interface.jl b/src/interface.jl
index c94146f8..7aa3bc58 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -15,7 +15,7 @@ mutable struct Leaf{R,S}  # mutable so that its identity encodes parameter shari
   state::S
   frozen::Bool  # ... and to allow freeze! to act on this.
 end
-Leaf(rule, state) = Leaf(rule, state, false)
+Leaf(rule, state; frozen::Bool = false) = Leaf(rule, state, frozen)
 
 @functor Leaf
 
@@ -44,12 +44,12 @@ function _setup(rule, x; cache)
   end
 end
 
-function Base.show(io::IO, ℓ::Leaf)  # show method is mostly to hide its long type!
+function Base.show(io::IO, ℓ::Leaf; colour = ℓ.frozen ? :cyan : :green)
   ioc = IOContext(io, :compact => true)
-  print(ioc, "Leaf(", ℓ.rule, ", ")
+  str = sprint(show, ℓ.rule; context = ioc)
+  printstyled(io, "Leaf(", str, ", "; color = colour)
   show(ioc, ℓ.state)
-  ℓ.frozen && print(ioc, ", true")
-  print(ioc, ")")
+  printstyled(io, ℓ.frozen ? ", frozen=true)" : ")"; color = colour)
 end
 
 ###
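The new two-step printing, `sprint(show, ℓ.rule; context = ioc)` followed by `printstyled`, first renders the rule compactly to a string and then emits the whole `Leaf(...)` text in one colour, cyan for frozen and green otherwise. A rough stand-alone illustration of the same pattern (the named tuple here is just a stand-in for a rule struct):

```julia
# Compact rendering first, coloured output second -- a sketch of the pattern above.
rule = (eta = 0.001f0, rho = 0.9f0)
ioc = IOContext(stdout, :compact => true)
str = sprint(show, rule; context = ioc)  # compact context prints 0.001, not 0.001f0
printstyled("Leaf(", str, ", ...)"; color = :green)
println()
```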
From d5d7cd1058265afc47bdf102ea29ef70522435fd Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 13 Oct 2022 17:22:30 -0400
Subject: [PATCH 3/8] tweak, simplify recursion

---
 src/adjust.jl    | 17 ++++++++++-------
 src/interface.jl |  2 +-
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/src/adjust.jl b/src/adjust.jl
index 67882bc8..f3751fed 100644
--- a/src/adjust.jl
+++ b/src/adjust.jl
@@ -5,9 +5,12 @@
 """
     Optimisers.freeze!(tree)
 
-Temporarily alters the state `tree = setup(rule, model)` so that parameters will not be updated.
-Can be applied to the state corresponding to only part of a model, for instance `model.layers[1]`.
-Un-done by [`thaw!`](@ref Optimisers.thaw).
+Temporarily alters the state `tree = setup(rule, model)` so that parameters
+will not be updated. Un-done by [`thaw!`](@ref Optimisers.thaw!).
+
+Can be applied to the state corresponding to only part of a model,
+for instance with `model::Chain`, to freeze `model.layers[1]` you
+should call `freeze!(tree.layers[1])`.
 
 # Example
 ```jldoctest
@@ -31,14 +34,14 @@ julia> s.x
 (Leaf(Momentum{Float32}(0.01, 0.9), [0.0]), ())
 ```
 """
-freeze!(tree) = (fmapstructure(freeze!, tree; exclude = x -> x isa Leaf); nothing)
+freeze!(tree) = foreach(freeze!, tree)
 freeze!(ℓ::Leaf) = (ℓ.frozen = true; nothing)
 
 """
     Optimisers.thaw!(tree)
 
-Un-does [`freeze!`](@ref Optimisers.freeze!) for all parameters,
-mutating every `Leaf(rule, state, true)` to `Leaf(rule, state, false)`.
+The reverse of [`freeze!`](@ref Optimisers.freeze!). Applies to all parameters,
+mutating every `Leaf(rule, state, frozen = true)` to `Leaf(rule, state, frozen = false)`.
 """
-thaw!(tree) = (fmapstructure(thaw!, tree; exclude = x -> x isa Leaf); nothing)
+thaw!(tree) = foreach(thaw!, tree)
 thaw!(ℓ::Leaf) = (ℓ.frozen = false; nothing)
diff --git a/src/interface.jl b/src/interface.jl
index 7aa3bc58..4a44ce81 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -49,7 +49,7 @@ function Base.show(io::IO, ℓ::Leaf; colour = ℓ.frozen ? :cyan : :green)
   str = sprint(show, ℓ.rule; context = ioc)
   printstyled(io, "Leaf(", str, ", "; color = colour)
   show(ioc, ℓ.state)
-  printstyled(io, ℓ.frozen ? ", frozen=true)" : ")"; color = colour)
+  printstyled(io, ℓ.frozen ? ", frozen = true)" : ")"; color = colour)
 end

From 838eee2bb53bf105b6bab40c94679c08788df018 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 13 Oct 2022 17:37:04 -0400
Subject: [PATCH 4/8] also block adjust

---
 src/adjust.jl | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/adjust.jl b/src/adjust.jl
index f3751fed..b5ddd77c 100644
--- a/src/adjust.jl
+++ b/src/adjust.jl
@@ -12,6 +12,8 @@ Can be applied to the state corresponding to only part of a model,
 for instance with `model::Chain`, to freeze `model.layers[1]` you
 should call `freeze!(tree.layers[1])`.
 
+Also prevents [`adjust`](@ref Optimisers.adjust) from changing the rule's parameters.
+
 # Example
 ```jldoctest
 julia> m = (x = ([1.0], 2.0), y = [3.0]);
@@ -66,6 +68,8 @@ through training.
 
 To change just the learning rate, provide a number `η::Real`.
 
+Does not affect any frozen parameters, set by [`freeze!`](@ref Optimisers.freeze!).
+
 # Example
 ```jldoctest
 julia> m = (vec = rand(Float32, 2), fun = sin);
@@ -103,8 +107,8 @@ adjust(tree; kw...) = map(st -> adjust(st; kw...), tree)
 
 adjust(::Nothing, ::Real) = nothing
 adjust(::Nothing; kw...) = nothing
 
-adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state)
-adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state)
+adjust(ℓ::Leaf, eta::Real) = ℓ.frozen ? ℓ : Leaf(adjust(ℓ.rule, eta), ℓ.state)
+adjust(ℓ::Leaf; kw...) = ℓ.frozen ? ℓ : Leaf(adjust(ℓ.rule; kw...), ℓ.state)
 
 """
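The `foreach` recursion from [PATCH 3/8] leans on the fact that `NamedTuple`s and `Tuple`s iterate their values, so the generic method walks each branch while the `Leaf` method terminates the descent; empty state slots like `()` simply iterate nothing. A self-contained sketch of the same scheme, again with a toy leaf type rather than the package's:

```julia
# MiniLeaf stands in for Optimisers.Leaf; mini_freeze! mirrors the foreach recursion above.
mutable struct MiniLeaf
    frozen::Bool
end
mini_freeze!(ℓ::MiniLeaf) = (ℓ.frozen = true; nothing)
mini_freeze!(tree) = foreach(mini_freeze!, tree)  # NamedTuples/Tuples iterate their values

tree = (x = (MiniLeaf(false), ()), y = MiniLeaf(false))
mini_freeze!(tree.x)             # freeze only the x branch
tree.x[1].frozen, tree.y.frozen  # (true, false)
```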
", frozen = true)" : ")"; color = colour) diff --git a/test/runtests.jl b/test/runtests.jl index 51e76053..b9369c22 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -221,6 +221,24 @@ end @test sc2.γ.state[2][1] ≈ [0.1, 0.2, 0.2] end + @testset "freeze/thaw" begin + m = (x=[1.0, 2.0], y=([3.0, 4.0], sin)); + st = Optimisers.setup(Descent(0.1), m); + Optimisers.freeze!(st.y) + st, m = Optimisers.update(st, m, (x=[1,10], y=([100,1000], nothing))); + @test m.x ≈ [0.9, 1.0] + @test m.y[1] == [3, 4] + + st = Optimisers.adjust(st, 0.2) + Optimisers.thaw!(st) + st, m = Optimisers.update(st, m, (x=[1,10], y=([100,1000], nothing))); + @test m.y[1] ≈ [-7.0, -96.0] + @test m.x ≈ [0.7, -1.0] + + @test_throws ArgumentError Optimisers.freeze!(m) + @test_throws ArgumentError Optimisers.thaw!(m) + end + @testset "forgotten gradient" begin x = [1.0, 2.0] sx = Optimisers.setup(Descent(), x) From 85c20bd7369a15a5384eec9610ff69598fc330d5 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 13 Oct 2022 21:29:24 -0400 Subject: [PATCH 6/8] add docs --- docs/src/index.md | 27 ++++++++++++++++++++++++++- src/adjust.jl | 2 +- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 65b441bb..7c3c32d7 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -138,6 +138,31 @@ Optimisers.trainable(x::Layer) = (; alpha = x.alpha) # must be a subset of chid st = Optimisers.setup(DecayDescent(0.1), Layer(3)) ``` +## Frozen Parameters + +To temporarily prevent training from affecting some parameters, +use [freeze!](@ref Optimisers.freeze!) and `thaw!`. +They work by mutating all `Leaf`s of the state tree, within the given limb: + +```julia +using Flux, Optimisers + +x = randn(Float32, 28, 28, 1, 1); +net = @autosize (size(x)...,) Chain( + Conv((3, 3), 1 => 3, stride=2, bias=false), Flux.flatten, Dense(_ => 2, relu), +) +opt = Optimisers.setup(Optimisers.Momentum(), net); + +net.layers[3] isa Dense # now freeze this layer's parameters: +Optimisers.freeze!(opt.layers[3]) + +Optimisers.update!(opt, net, gradient(m -> sum(m(x)), net)...); + +opt # bias = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0], frozen = true) + +Optimisers.thaw!(opt) +``` + ## Tied Parameters If the same array appears twice (or more) in the model, [Functors.jl](https://fluxml.ai/Functors.jl) should recognise this. @@ -159,7 +184,7 @@ st.layers.enc.layers[1].weight === st.layers.dec.layers[1].weight.parent # true This identification relies on `===`, and will work for ordinary `Array`s and `CuArray`s. It will not at present work for `reshape`d arrays, nor for immutable arrays such as those from StaticArrays.jl. 
From 85c20bd7369a15a5384eec9610ff69598fc330d5 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 13 Oct 2022 21:29:24 -0400
Subject: [PATCH 6/8] add docs

---
 docs/src/index.md | 27 ++++++++++++++++++++++++++-
 src/adjust.jl     |  2 +-
 2 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/docs/src/index.md b/docs/src/index.md
index 65b441bb..7c3c32d7 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -138,6 +138,31 @@ Optimisers.trainable(x::Layer) = (; alpha = x.alpha)  # must be a subset of chid
 st = Optimisers.setup(DecayDescent(0.1), Layer(3))
 ```
 
+## Frozen Parameters
+
+To temporarily prevent training from affecting some parameters,
+use [freeze!](@ref Optimisers.freeze!) and `thaw!`.
+They work by mutating all `Leaf`s of the state tree, within the given limb:
+
+```julia
+using Flux, Optimisers
+
+x = randn(Float32, 28, 28, 1, 1);
+net = @autosize (size(x)...,) Chain(
+      Conv((3, 3), 1 => 3, stride=2, bias=false), Flux.flatten, Dense(_ => 2, relu),
+      )
+opt = Optimisers.setup(Optimisers.Momentum(), net);
+
+net.layers[3] isa Dense # now freeze this layer's parameters:
+Optimisers.freeze!(opt.layers[3])
+
+Optimisers.update!(opt, net, gradient(m -> sum(m(x)), net)...);
+
+opt # bias = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0], frozen = true)
+
+Optimisers.thaw!(opt)
+```
+
 ## Tied Parameters
 
 If the same array appears twice (or more) in the model, [Functors.jl](https://fluxml.ai/Functors.jl) should recognise this.
@@ -159,7 +184,7 @@ st = Optimisers.setup(DecayDescent(0.1), model);
 st.layers.enc.layers[1].weight === st.layers.dec.layers[1].weight.parent # true
 ```
 
 This identification relies on `===`, and will work for ordinary `Array`s and `CuArray`s.
 It will not at present work for `reshape`d arrays, nor for immutable arrays such as those from StaticArrays.jl.
-
+
 ## Obtaining a flat parameter vector
diff --git a/src/adjust.jl b/src/adjust.jl
index b5ddd77c..90376fd6 100644
--- a/src/adjust.jl
+++ b/src/adjust.jl
@@ -28,7 +28,7 @@ julia> m
 (x = ([1.0], 2.0), y = [-0.14159258336972558])
 
 julia> s
-(x = (Leaf(Momentum{Float32}(0.01, 0.9), [0.0], frozen=true), ()), y = Leaf(Momentum{Float32}(0.01, 0.9), [3.14159]))
+(x = (Leaf(Momentum{Float32}(0.01, 0.9), [0.0], frozen = true), ()), y = Leaf(Momentum{Float32}(0.01, 0.9), [3.14159]))
 
 julia> Optimisers.thaw!(s)
 
From e3ca35e2b11e8d74615ebace64f64b3d5e66ac96 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 14 Oct 2022 15:21:13 -0400
Subject: [PATCH 7/8] decouple from adjust, tweak words

---
 docs/src/index.md |  3 ++-
 src/adjust.jl     | 10 +++-------
 test/runtests.jl  |  2 +-
 3 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/docs/src/index.md b/docs/src/index.md
index 7c3c32d7..976d6e01 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -142,7 +142,7 @@ st = Optimisers.setup(DecayDescent(0.1), Layer(3))
 To temporarily prevent training from affecting some parameters,
 use [freeze!](@ref Optimisers.freeze!) and `thaw!`.
-They work by mutating all `Leaf`s of the state tree, within the given limb:
+They work by mutating all `Leaf`s of the state tree, or part of it.
 
 ```julia
 using Flux, Optimisers
 
@@ -158,6 +158,7 @@ Optimisers.freeze!(opt.layers[3])
 
 Optimisers.update!(opt, net, gradient(m -> sum(m(x)), net)...);
 
+net.layers[3].bias # still zero, and its momentum is zero too:
 opt # bias = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0], frozen = true)
 
 Optimisers.thaw!(opt)
diff --git a/src/adjust.jl b/src/adjust.jl
index 90376fd6..a8123676 100644
--- a/src/adjust.jl
+++ b/src/adjust.jl
@@ -1,5 +1,5 @@
 ###
-### freeze!
+### freezing
 ###
 
 """
@@ -12,8 +12,6 @@ Can be applied to the state corresponding to only part of a model,
 for instance with `model::Chain`, to freeze `model.layers[1]` you
 should call `freeze!(tree.layers[1])`.
 
-Also prevents [`adjust`](@ref Optimisers.adjust) from changing the rule's parameters.
-
 # Example
 ```jldoctest
 julia> m = (x = ([1.0], 2.0), y = [3.0]);
@@ -66,8 +64,6 @@ through training.
 
 To change just the learning rate, provide a number `η::Real`.
 
-Does not affect any frozen parameters, set by [`freeze!`](@ref Optimisers.freeze!).
-
 # Example
 ```jldoctest
 julia> m = (vec = rand(Float32, 2), fun = sin);
@@ -107,8 +103,8 @@ adjust(tree; kw...) = map(st -> adjust(st; kw...), tree)
 
 adjust(::Nothing, ::Real) = nothing
 adjust(::Nothing; kw...) = nothing
 
-adjust(ℓ::Leaf, eta::Real) = ℓ.frozen ? ℓ : Leaf(adjust(ℓ.rule, eta), ℓ.state)
-adjust(ℓ::Leaf; kw...) = ℓ.frozen ? ℓ : Leaf(adjust(ℓ.rule; kw...), ℓ.state)
+adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state, ℓ.frozen)
+adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state, ℓ.frozen)
 
 """
diff --git a/test/runtests.jl b/test/runtests.jl
index b9369c22..a8cef6f6 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -232,7 +232,7 @@ end
     st = Optimisers.adjust(st, 0.2)
     Optimisers.thaw!(st)
     st, m = Optimisers.update(st, m, (x=[1,10], y=([100,1000], nothing)));
-    @test m.y[1] ≈ [-7.0, -96.0]
+    @test m.y[1] ≈ [-17.0, -196.0]
     @test m.x ≈ [0.7, -1.0]
 
     @test_throws ArgumentError Optimisers.freeze!(m)

From 562be6260fe6c0158ab86eb2ea8f66e840341a3b Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 24 Nov 2022 22:35:31 -0500
Subject: [PATCH 8/8] tweak the doc example

---
 docs/src/index.md | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/docs/src/index.md b/docs/src/index.md
index 976d6e01..659ae837 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -155,13 +155,14 @@ opt = Optimisers.setup(Optimisers.Momentum(), net);
 
 net.layers[3] isa Dense # now freeze this layer's parameters:
 Optimisers.freeze!(opt.layers[3])
+opt.layers[3].bias # confirm: Leaf(Momentum(...), [0.0, 0.0], frozen = true)
 
 Optimisers.update!(opt, net, gradient(m -> sum(m(x)), net)...);
 
-net.layers[3].bias # still zero, and its momentum is zero too:
-opt # bias = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0], frozen = true)
+net.layers[3].bias # still zero, and its momentum is too:
 
 Optimisers.thaw!(opt)
+opt.layers[3].bias # Leaf(Momentum(...), [0.0, 0.0])
 ```
 
 ## Tied Parameters
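Taken together, the final behaviour can be exercised with plain Optimisers.jl, without the Flux layers the docs use. A sketch under the assumption that the installed version includes this series; the `model`/`grads` shapes here are illustrative only:

```julia
using Optimisers

model = (dense = (W = rand(2, 3), b = zeros(2)), scale = [1.0])
state = Optimisers.setup(Momentum(0.01, 0.9), model)

Optimisers.freeze!(state.dense)  # freeze one branch of the state tree
grads = (dense = (W = ones(2, 3), b = ones(2)), scale = [0.5])
state, model = Optimisers.update(state, model, grads)
# model.dense is untouched; model.scale has taken a Momentum step

state = Optimisers.adjust(state, 0.02)  # after PATCH 7, this reaches frozen leaves too
Optimisers.thaw!(state)                 # every parameter trainable again
```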