Skip to content

Commit 9f1daaf

Browse files
committed
add tests
1 parent 5cf2607 commit 9f1daaf

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

src/interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ end
4646

4747
function Base.show(io::IO, ℓ::Leaf; colour =.frozen ? :cyan : :green)
4848
ioc = IOContext(io, :compact => true)
49-
str = sprint(show, ℓ.rule; context = ioc)
49+
str = sprint(show, ℓ.rule; context = ioc) # produces Adam{Float32}(0.001, ... not 0.001f0
5050
printstyled(io, "Leaf(", str, ", "; color = colour)
5151
show(ioc, ℓ.state)
5252
printstyled(io, ℓ.frozen ? ", frozen = true)" : ")"; color = colour)

test/runtests.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,24 @@ end
221221
@test sc2.γ.state[2][1] [0.1, 0.2, 0.2]
222222
end
223223

224+
@testset "freeze/thaw" begin
225+
m = (x=[1.0, 2.0], y=([3.0, 4.0], sin));
226+
st = Optimisers.setup(Descent(0.1), m);
227+
Optimisers.freeze!(st.y)
228+
st, m = Optimisers.update(st, m, (x=[1,10], y=([100,1000], nothing)));
229+
@test m.x [0.9, 1.0]
230+
@test m.y[1] == [3, 4]
231+
232+
st = Optimisers.adjust(st, 0.2)
233+
Optimisers.thaw!(st)
234+
st, m = Optimisers.update(st, m, (x=[1,10], y=([100,1000], nothing)));
235+
@test m.y[1] [-7.0, -96.0]
236+
@test m.x [0.7, -1.0]
237+
238+
@test_throws ArgumentError Optimisers.freeze!(m)
239+
@test_throws ArgumentError Optimisers.thaw!(m)
240+
end
241+
224242
@testset "forgotten gradient" begin
225243
x = [1.0, 2.0]
226244
sx = Optimisers.setup(Descent(), x)

0 commit comments

Comments
 (0)