Skip to content

Commit 9623ad7

Browse files
committed
decouple from adjust, tweak words
1 parent e38a9eb commit 9623ad7

File tree

3 files changed

+5
-9
lines changed

3 files changed

+5
-9
lines changed

docs/src/index.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ st = Optimisers.setup(DecayDescent(0.1), Layer(3))
142142

143143
To temporarily prevent training from affecting some parameters,
144144
use [freeze!](@ref Optimisers.freeze!) and `thaw!`.
145-
They work by mutating all `Leaf`s of the state tree, within the given limb:
145+
They work by mutating all `Leaf`s of the state tree, or part of it.
146146

147147
```julia
148148
using Flux, Optimisers

src/adjust.jl

+3-7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
###
2-
### freeze!
2+
### freezing
33
###
44

55
"""
@@ -12,8 +12,6 @@ Can be applied to the state corresponding to only part of a model,
1212
for instance with `model::Chain`, to freeze `model.layers[1]` you
1313
should call `freeze!(tree.layers[1])`.
1414
15-
Also prevents [`adjust`](@ref Optimisers.adjust) from changing the rule's parameters.
16-
1715
# Example
1816
```jldoctest
1917
julia> m = (x = ([1.0], 2.0), y = [3.0]);
@@ -66,8 +64,6 @@ through training.
6664
6765
To change just the learning rate, provide a number `η::Real`.
6866
69-
Does not affect any frozen parameters, set by [`freeze!`](@ref Optimisers.freeze!).
70-
7167
# Example
7268
```jldoctest
7369
julia> m = (vec = rand(Float32, 2), fun = sin);
@@ -107,8 +103,8 @@ adjust(tree; kw...) = map(st -> adjust(st; kw...), tree)
107103
adjust(::Nothing, ::Real) = nothing
108104
adjust(::Nothing; kw...) = nothing
109105

110-
adjust(ℓ::Leaf, eta::Real) = .frozen ?: Leaf(adjust(ℓ.rule, eta), ℓ.state)
111-
adjust(ℓ::Leaf; kw...) = .frozen ?: Leaf(adjust(ℓ.rule; kw...), ℓ.state)
106+
adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state, ℓ.frozen)
107+
adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state, ℓ.frozen)
112108

113109

114110
"""

test/runtests.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ end
232232
st = Optimisers.adjust(st, 0.2)
233233
Optimisers.thaw!(st)
234234
st, m = Optimisers.update(st, m, (x=[1,10], y=([100,1000], nothing)));
235-
@test m.y[1] [-7.0, -96.0]
235+
@test m.y[1] [-17.0, -196.0]
236236
@test m.x [0.7, -1.0]
237237

238238
@test_throws ArgumentError Optimisers.freeze!(m)

0 commit comments

Comments
 (0)