using Flux: withgradient, DataLoader
using Optimisers: Optimisers
using ProgressMeter: ProgressMeter, Progress, next!

#=

This grew out of the explicit-mode upgrade here:
https://github.com/FluxML/Flux.jl/pull/2082

=#

"""
    shinkansen!(loss, model, data...; state, epochs=1, [batchsize, keywords...])

This is a re-design of `train!`:

* The loss function must accept the remaining arguments: `loss(model, data...)`
* The optimiser state from `setup` must be passed to the keyword `state`.

By default it calls `gradient(loss, model, data...)` just like that,
in the same order as the arguments. If you specify `epochs = 100`, then it will do this 100 times.
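
Schematically, each epoch without `batchsize` performs roughly the following
(a sketch of the update step only; the real loop also records the loss and drives a progress bar):

```
grad = Flux.gradient(loss, model, data...)[1]
Optimisers.update!(state, model, grad)
```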

But if you specify `batchsize = 32`, then it first wraps the remaining arguments in
`DataLoader(data; batchsize)`, and uses that to generate smaller arrays to feed to `gradient`.
All other keywords are passed to `DataLoader`, e.g. to shuffle batches.
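
For instance, with the `X`, `Y` from the example below (an illustrative sketch, not part of this file):

```
loader = Flux.DataLoader((X, Y); batchsize=16, shuffle=true)
x1, y1 = first(loader)  # one batch: 16 columns of X and the matching 16 columns of Y
```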

Returns the loss from every call to `gradient`: a vector with one entry per epoch when
no `batchsize` is given, otherwise one entry per batch (reshaped into a matrix with one
column per epoch when every epoch has the same number of batches).

# Example
```
X = repeat(hcat(digits.(0:3, base=2, pad=2)...), 1, 32)
Y = Flux.onehotbatch(xor.(eachrow(X)...), 0:1)

model = Chain(Dense(2 => 3, sigmoid), BatchNorm(3), Dense(3 => 2))
state = Flux.setup(Adam(0.1, (0.7, 0.95)), model)
# state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model)  # for now

shinkansen!(model, X, Y; state, epochs=100, batchsize=16, shuffle=true) do m, x, y
    Flux.logitcrossentropy(m(x), y)
end

all((softmax(model(X)) .> 0.5) .== Y)
```
"""
function shinkansen!(loss::Function, model, data...; state, epochs=1, batchsize=nothing, kw...)
    if batchsize !== nothing
        # Mini-batch path: wrap the data in a DataLoader and take one step per batch.
        loader = DataLoader(data; batchsize, kw...)
        losses = Vector{Float32}[]
        prog = Progress(length(loader) * epochs)

        for e in 1:epochs
            eplosses = Float32[]
            for (i, d) in enumerate(loader)
                l, (g, _...) = withgradient(loss, model, d...)
                isfinite(l) || error("loss is $l, on batch $i, epoch $e")
                Optimisers.update!(state, model, g)
                push!(eplosses, l)
                next!(prog; showvalues=[(:epoch, e), (:loss, l)])
            end
            push!(losses, eplosses)
        end

        return allequal(size.(losses)) ? reduce(hcat, losses) : losses
    else
        # Full-batch path: one gradient call per epoch on the whole dataset.
        losses = Float32[]
        prog = Progress(epochs)

        for e in 1:epochs
            l, (g, _...) = withgradient(loss, model, data...)
            isfinite(l) || error("loss is $l, on epoch $e")
            Optimisers.update!(state, model, g)
            push!(losses, l)
            next!(prog; showvalues=[(:epoch, e), (:loss, l)])
        end

        return losses
    end
end
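
#=
A minimal sketch of the full-batch path (assuming the same toy `X`, `Y`, `model`, `state`
as in the docstring example above): leaving out `batchsize` takes one gradient step per
epoch on the whole dataset and returns a plain vector of losses.

    losses = shinkansen!(model, X, Y; state, epochs=100) do m, x, y
        Flux.logitcrossentropy(m(x), y)
    end
    length(losses) == 100
=#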