@@ -138,6 +138,31 @@ Optimisers.trainable(x::Layer) = (; alpha = x.alpha) # must be a subset of children
st = Optimisers.setup(DecayDescent(0.1), Layer(3))
```

+## Frozen Parameters
+
+To temporarily prevent training from affecting some parameters,
+use [`freeze!`](@ref Optimisers.freeze!) and `thaw!`.
+They work by mutating all `Leaf`s of the state tree, within the given limb:
+
+```julia
+using Flux, Optimisers
+
+x = randn(Float32, 28, 28, 1, 1);
+net = @autosize (size(x)...,) Chain(
+    Conv((3, 3), 1 => 3, stride=2, bias=false), Flux.flatten, Dense(_ => 2, relu),
+)
+opt = Optimisers.setup(Optimisers.Momentum(), net);
+
+net.layers[3] isa Dense  # now freeze this layer's parameters:
+Optimisers.freeze!(opt.layers[3])
+
+Optimisers.update!(opt, net, gradient(m -> sum(m(x)), net)...);
+
+opt  # bias = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0], frozen = true)
+
+Optimisers.thaw!(opt)
+```
+
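A minimal sketch of the same mechanism, assuming only a toy `NamedTuple` model and the `Descent` rule (not the Flux network above): parameters under a frozen limb keep their values across an update, while the rest move as usual.

```julia
using Optimisers

m = (a = (w = ones(Float32, 2),), b = ones(Float32, 2))    # toy "model"
s = Optimisers.setup(Optimisers.Descent(0.1f0), m)

Optimisers.freeze!(s.a)        # mark every Leaf under s.a as frozen

g = (a = (w = ones(Float32, 2),), b = ones(Float32, 2))    # stand-in gradient
s, m = Optimisers.update(s, m, g)

m.a.w   # unchanged, still [1.0, 1.0]: this limb is frozen
m.b     # now [0.9, 0.9]: still trainable

Optimisers.thaw!(s.a)          # re-enable training for just that limb
```

Because `freeze!` and `thaw!` are pointed at `s.a` only, the state under `s.b` is never touched and keeps training throughout.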
## Tied Parameters

If the same array appears twice (or more) in the model, [Functors.jl](https://fluxml.ai/Functors.jl) should recognise this.
@@ -159,7 +184,7 @@ st.layers.enc.layers[1].weight === st.layers.dec.layers[1].weight.parent # true
This identification relies on `===`, and will work for ordinary `Array`s and `CuArray`s.
It will not at present work for `reshape`d arrays, nor for immutable arrays such as those
from StaticArrays.jl.
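A minimal sketch of that identification, assuming only a toy `NamedTuple` model: two fields holding the very same array share one `Leaf` (and so one momentum buffer), while a `reshape` of that array is a different object and is not matched.

```julia
using Optimisers

w = rand(Float32, 3, 3)
m = (shared1 = w, shared2 = w, other = rand(Float32, 3, 3))  # shared1 and shared2 hold the same object

s = Optimisers.setup(Optimisers.Momentum(), m)

s.shared1 === s.shared2  # true: the tie is detected, one Leaf (and one momentum buffer) is shared
s.shared1 === s.other    # false: an unrelated array gets its own Leaf

reshape(w, 9) === w      # false: a reshaped version of w would be treated as a separate parameter
```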
-
+

## Obtaining a flat parameter vector