Commit 8650d54

A cleaner more powerful train! function (#3)
1 parent 660ef33 commit 8650d54

File tree: 5 files changed, +111 -1 lines changed

  Project.toml
  README.md
  src/Fluxperimental.jl
  src/train.jl
  test/train.jl

Project.toml

Lines changed: 2 additions & 0 deletions
@@ -6,12 +6,14 @@ version = "0.1.0"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
 Flux = "0.13.7"
 NNlib = "0.8.10"
 Optimisers = "0.2.10"
+ProgressMeter = "1.7.2"
 Zygote = "0.6.49"
 julia = "1.6"

README.md

Lines changed: 1 addition & 1 deletion
@@ -34,4 +34,4 @@ As will any features which migrate to Flux itself.
 ## Current Features

 * Layers `Split` and `Join`
-
+* A more advanced `train!`

src/Fluxperimental.jl

Lines changed: 3 additions & 0 deletions
@@ -5,4 +5,7 @@ using Flux
 include("split_join.jl")
 export Split, Join

+include("train.jl")
+export shinkansen!
+
 end # module Fluxperimental
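
Because the module now exports `shinkansen!`, a plain `using Fluxperimental` makes the name available unqualified. A minimal sketch; the model, data, and `Flux.mse` loss below are placeholders for illustration, not part of this commit:

```julia
using Flux, Fluxperimental, Optimisers

model = Chain(Dense(2 => 1))                        # placeholder model
state = Optimisers.setup(Optimisers.Adam(), model)  # optimiser state, as `shinkansen!` expects
x, y  = rand(Float32, 2, 8), rand(Float32, 1, 8)    # placeholder data

shinkansen!((m, x, y) -> Flux.mse(m(x), y), model, x, y; state, epochs=10)
```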

src/train.jl

Lines changed: 78 additions & 0 deletions
New file:

using Flux: withgradient, DataLoader
using Optimisers: Optimisers
using ProgressMeter: ProgressMeter, Progress, next!

#=

This grew out of the explicit-mode upgrade here:
https://github.com/FluxML/Flux.jl/pull/2082

=#

"""
    shinkansen!(loss, model, data...; state, epochs=1, [batchsize, keywords...])

This is a re-design of `train!`:

* The loss function must accept the remaining arguments: `loss(model, data...)`
* The optimiser state from `setup` must be passed to the keyword `state`.

By default it calls `gradient(loss, model, data...)` just like that,
with the arguments in the same order. If you specify `epochs = 100`, then it will do this 100 times.

But if you specify `batchsize = 32`, then it first makes `DataLoader(data...; batchsize)`,
and uses that to generate smaller arrays to feed to `gradient`.
All other keywords are passed to `DataLoader`, e.g. to shuffle batches.

Returns the loss from every call.

# Example
```
X = repeat(hcat(digits.(0:3, base=2, pad=2)...), 1, 32)
Y = Flux.onehotbatch(xor.(eachrow(X)...), 0:1)

model = Chain(Dense(2 => 3, sigmoid), BatchNorm(3), Dense(3 => 2))
state = Flux.setup(Adam(0.1, (0.7, 0.95)), model)
# state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model)  # for now

shinkansen!(model, X, Y; state, epochs=100, batchsize=16, shuffle=true) do m, x, y
  Flux.logitcrossentropy(m(x), y)
end

all((softmax(model(X)) .> 0.5) .== Y)
```
"""
function shinkansen!(loss::Function, model, data...; state, epochs=1, batchsize=nothing, kw...)
  if batchsize !== nothing
    # Mini-batch path: wrap the data in a DataLoader and take one gradient step per batch.
    loader = DataLoader(data; batchsize, kw...)
    losses = Vector{Float32}[]
    prog = Progress(length(loader) * epochs)

    for e in 1:epochs
      eplosses = Float32[]
      for (i, d) in enumerate(loader)
        l, (g, _...) = withgradient(loss, model, d...)
        isfinite(l) || error("loss is $l, on batch $i, epoch $e")
        Optimisers.update!(state, model, g)
        push!(eplosses, l)
        next!(prog; showvalues=[(:epoch, e), (:loss, l)])
      end
      push!(losses, eplosses)
    end

    return allequal(size.(losses)) ? reduce(hcat, losses) : losses
  else
    # Full-batch path: one gradient step per epoch on the whole dataset.
    losses = Float32[]
    prog = Progress(epochs)

    for e in 1:epochs
      l, (g, _...) = withgradient(loss, model, data...)
      isfinite(l) || error("loss is $l, on epoch $e")
      Optimisers.update!(state, model, g)
      push!(losses, l)
      next!(prog; showvalues=[(:epoch, e), (:loss, l)])
    end

    return losses
  end
end
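
Continuing from the docstring example above: with `batchsize` set, `shinkansen!` returns a batches-per-epoch × epochs `Float32` matrix when every epoch has the same number of batches, otherwise a vector of per-epoch vectors. A minimal sketch of summarising either shape per epoch; the `Statistics.mean` post-processing is illustrative, not part of this commit:

```julia
using Statistics: mean

losses = shinkansen!(model, X, Y; state, epochs=100, batchsize=16, shuffle=true) do m, x, y
    Flux.logitcrossentropy(m(x), y)
end

# Per-epoch mean loss, whichever shape `shinkansen!` returned.
epoch_means = losses isa AbstractMatrix ? vec(mean(losses; dims=1)) : mean.(losses)
```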

test/train.jl

Lines changed: 27 additions & 0 deletions
New file:

import Flux, Fluxperimental, Optimisers

@testset "shinkansen!" begin

X = repeat(hcat(digits.(0:3, base=2, pad=2)...), 1, 32)
Y = Flux.onehotbatch(xor.(eachrow(X)...), 0:1)

model = Flux.Chain(Flux.Dense(2 => 3, Flux.sigmoid), Flux.BatchNorm(3), Flux.Dense(3 => 2))
state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model)

Fluxperimental.shinkansen!(model, X, Y; state, epochs=100) do m, x, y
  Flux.logitcrossentropy(m(x), y)
end

@test all((Flux.softmax(model(X)) .> 0.5) .== Y)

model = Flux.Chain(Flux.Dense(2 => 3, Flux.sigmoid), Flux.BatchNorm(3), Flux.Dense(3 => 2))
state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model)

Fluxperimental.shinkansen!(model, X, Y; state, epochs=100, batchsize=16, shuffle=true) do m, x, y
  Flux.logitcrossentropy(m(x), y)
end

@test all((Flux.softmax(model(X)) .> 0.5) .== Y)

end
