Skip to content

Commit e38a9eb

Browse files
committed
add docs
1 parent 9f1daaf commit e38a9eb

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

docs/src/index.md

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,31 @@ Optimisers.trainable(x::Layer) = (; alpha = x.alpha) # must be a subset of chid
138138
st = Optimisers.setup(DecayDescent(0.1), Layer(3))
139139
```
140140

141+
## Frozen Parameters
142+
143+
To temporarily prevent training from affecting some parameters,
144+
use [freeze!](@ref Optimisers.freeze!) and `thaw!`.
145+
They work by mutating all `Leaf`s of the state tree, within the given limb:
146+
147+
```julia
148+
using Flux, Optimisers
149+
150+
x = randn(Float32, 28, 28, 1, 1);
151+
net = @autosize (size(x)...,) Chain(
152+
Conv((3, 3), 1 => 3, stride=2, bias=false), Flux.flatten, Dense(_ => 2, relu),
153+
)
154+
opt = Optimisers.setup(Optimisers.Momentum(), net);
155+
156+
net.layers[3] isa Dense # now freeze this layer's parameters:
157+
Optimisers.freeze!(opt.layers[3])
158+
159+
Optimisers.update!(opt, net, gradient(m -> sum(m(x)), net)...);
160+
161+
opt # bias = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0], frozen = true)
162+
163+
Optimisers.thaw!(opt)
164+
```
165+
141166
## Tied Parameters
142167

143168
If the same array appears twice (or more) in the model, [Functors.jl](https://fluxml.ai/Functors.jl) should recognise this.
@@ -159,7 +184,7 @@ st.layers.enc.layers[1].weight === st.layers.dec.layers[1].weight.parent # true
159184
This identification relies on `===`, and will work for ordinary `Array`s and `CuArray`s.
160185
It will not at present work for `reshape`d arrays, nor for immutable arrays such as those
161186
from StaticArrays.jl.
162-
187+
163188

164189
## Obtaining a flat parameter vector
165190

src/adjust.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ julia> m
2828
(x = ([1.0], 2.0), y = [-0.14159258336972558])
2929
3030
julia> s
31-
(x = (Leaf(Momentum{Float32}(0.01, 0.9), [0.0], frozen=true), ()), y = Leaf(Momentum{Float32}(0.01, 0.9), [3.14159]))
31+
(x = (Leaf(Momentum{Float32}(0.01, 0.9), [0.0], frozen = true), ()), y = Leaf(Momentum{Float32}(0.01, 0.9), [3.14159]))
3232
3333
julia> Optimisers.thaw!(s)
3434

0 commit comments

Comments
 (0)