diff --git a/README.md b/README.md
index 1dec5940e7..b3dda36a0f 100644
--- a/README.md
+++ b/README.md
@@ -18,23 +18,23 @@
 Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable.
 
-Works best with [Julia 1.8](https://julialang.org/downloads/) or later. Here's a simple example to try it out:
+Works best with [Julia 1.8](https://julialang.org/downloads/) or later. Here's a very short example to try it out:
 ```julia
-using Flux  # should install everything for you, including CUDA
+using Flux, Plots
+data = [([x], 2x-x^3) for x in -2:0.1f0:2]
 
-x = hcat(digits.(0:3, base=2, pad=2)...) |> gpu  # let's solve the XOR problem!
-y = Flux.onehotbatch(xor.(eachrow(x)...), 0:1) |> gpu
-data = ((Float32.(x), y) for _ in 1:100)  # an iterator making Tuples
+model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)
 
-model = Chain(Dense(2 => 3, sigmoid), BatchNorm(3), Dense(3 => 2)) |> gpu
-optim = Adam(0.1, (0.7, 0.95))
-mloss(x, y) = Flux.logitcrossentropy(model(x), y)  # closes over model
+mloss(x,y) = (model(x) - y)^2
+optim = Flux.Adam()
+for epoch in 1:1000
+  Flux.train!(mloss, Flux.params(model), data, optim)
+end
 
-Flux.train!(mloss, Flux.params(model), data, optim)  # updates model & optim
-
-all((softmax(model(x)) .> 0.5) .== y)  # usually 100% accuracy.
+plot(x -> 2x-x^3, -2, 2, legend=false)
+scatter!(-2:0.1:2, [model([x]) for x in -2:0.1:2])
 ```
 
-See the [documentation](https://fluxml.github.io/Flux.jl/) for details, or the [model zoo](https://github.com/FluxML/model-zoo/) for examples. Ask questions on the [Julia discourse](https://discourse.julialang.org/) or [slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866).
+The [quickstart page](https://fluxml.ai/Flux.jl/stable/models/quickstart/) has a longer example. See the [documentation](https://fluxml.github.io/Flux.jl/) for details, or the [model zoo](https://github.com/FluxML/model-zoo/) for examples. Ask questions on the [Julia discourse](https://discourse.julialang.org/) or [slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866).
 
 If you use Flux in your research, please [cite](CITATION.bib) our work.
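Note on the new README example above: it trains with the implicit-parameters interface, `Flux.train!(mloss, Flux.params(model), data, optim)`. For comparison only, here is a minimal sketch (not part of the patch) of the same fit written as an explicit loop over `Flux.gradient` and `Flux.update!`, assuming the same `model`, `data` and `optim` definitions as in the example:

```julia
using Flux

data  = [([x], 2x - x^3) for x in -2:0.1f0:2]
model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)
optim = Flux.Adam()
pars  = Flux.params(model)   # implicit collection of the model's trainable arrays

for epoch in 1:1000
    for (x, y) in data
        grad = Flux.gradient(() -> (model(x) - y)^2, pars)  # same squared-error loss as mloss
        Flux.update!(optim, pars, grad)                     # mutates model arrays & optimiser state
    end
end
```

Each pass computes the gradient of the loss for one `(x, y)` pair and immediately updates the parameters, which is what `train!` does internally for every element of `data`.
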
diff --git a/docs/src/assets/quickstart/loss.png b/docs/src/assets/quickstart/loss.png
new file mode 100644
index 0000000000..8cfa5523d2
Binary files /dev/null and b/docs/src/assets/quickstart/loss.png differ
diff --git a/docs/src/assets/oneminute.png b/docs/src/assets/quickstart/oneminute.png
similarity index 100%
rename from docs/src/assets/oneminute.png
rename to docs/src/assets/quickstart/oneminute.png
diff --git a/docs/src/models/quickstart.md b/docs/src/models/quickstart.md
index 28685176e9..f93be4de95 100644
--- a/docs/src/models/quickstart.md
+++ b/docs/src/models/quickstart.md
@@ -6,45 +6,54 @@ If you haven't, then you might prefer the [Fitting a Straight Line](overview.md)
 ```julia
 # With Julia 1.7+, this will prompt if neccessary to install everything, including CUDA:
-using Flux, Statistics
+using Flux, Statistics, ProgressMeter
 
 # Generate some data for the XOR problem: vectors of length 2, as columns of a matrix:
 noisy = rand(Float32, 2, 1000)                                    # 2×1000 Matrix{Float32}
-truth = map(col -> xor(col...), eachcol(noisy .> 0.5))            # 1000-element Vector{Bool}
+truth = [xor(col[1]>0.5, col[2]>0.5) for col in eachcol(noisy)]   # 1000-element Vector{Bool}
 
 # Define our model, a multi-layer perceptron with one hidden layer of size 3:
-model = Chain(Dense(2 => 3, tanh), BatchNorm(3), Dense(3 => 2), softmax)
+model = Chain(
+    Dense(2 => 3, tanh),   # activation function inside layer
+    BatchNorm(3),
+    Dense(3 => 2),
+    softmax) |> gpu        # move model to GPU, if available
 
 # The model encapsulates parameters, randomly initialised. Its initial output is:
-out1 = model(noisy)                                 # 2×1000 Matrix{Float32}
+out1 = model(noisy |> gpu) |> cpu                   # 2×1000 Matrix{Float32}
 
-# To train the model, we use batches of 64 samples:
-mat = Flux.onehotbatch(truth, [true, false])                      # 2×1000 OneHotMatrix
-data = Flux.DataLoader((noisy, mat), batchsize=64, shuffle=true);
-first(data) .|> summary                             # ("2×64 Matrix{Float32}", "2×64 Matrix{Bool}")
+# To train the model, we use batches of 64 samples, and one-hot encoding:
+target = Flux.onehotbatch(truth, [true, false])                   # 2×1000 OneHotMatrix
+loader = Flux.DataLoader((noisy, target) |> gpu, batchsize=64, shuffle=true);
+# 16-element DataLoader with first element: (2×64 Matrix{Float32}, 2×64 OneHotMatrix)
 
 pars = Flux.params(model)  # contains references to arrays in model
 opt = Flux.Adam(0.01)      # will store optimiser momentum, etc.
 
 # Training loop, using the whole data set 1000 times:
-for epoch in 1:1_000
-    Flux.train!(pars, data, opt) do x, y
-        # First argument of train! is a loss function, here defined by a `do` block.
-        # This gets x and y, each a 2×64 Matrix, from data, and compares:
-        Flux.crossentropy(model(x), y)
+losses = []
+@showprogress for epoch in 1:1_000
+    for (x, y) in loader
+        loss, grad = Flux.withgradient(pars) do
+            # Evaluate model and loss inside gradient context:
+            y_hat = model(x)
+            Flux.crossentropy(y_hat, y)
+        end
+        Flux.update!(opt, pars, grad)
+        push!(losses, loss)  # logging, outside gradient context
     end
 end
 
-pars  # has changed!
+pars  # parameters, momenta and output have all changed
 opt
-out2 = model(noisy)
+out2 = model(noisy |> gpu) |> cpu   # first row is prob. of true, second row p(false)
 
 mean((out2[1,:] .> 0.5) .== truth)  # accuracy 94% so far!
 ```
 
-![](../assets/oneminute.png)
+![](../assets/quickstart/oneminute.png)
 
-```
+```julia
 using Plots  # to draw the above figure
 
 p_true = scatter(noisy[1,:], noisy[2,:], zcolor=truth, title="True classification", legend=false)
@@ -54,26 +63,43 @@ p_done = scatter(noisy[1,:], noisy[2,:], zcolor=out2[1,:], title="Trained networ
 
 plot(p_true, p_raw, p_done, layout=(1,3), size=(1000,330))
 ```
 
+```@raw html
+
+```
+
+Here's the loss during training:
+
+```julia
+plot(losses; xaxis=(:log10, "iteration"),
+    yaxis="loss", label="per batch")
+n = length(loader)
+plot!(n:n:length(losses), mean.(Iterators.partition(losses, n)),
+    label="epoch mean", dpi=200)
+```
+
 This XOR ("exclusive or") problem is a variant of the famous one which drove Minsky and Papert to invent deep neural networks in 1969. For small values of "deep" -- this has one hidden layer, while earlier perceptrons had none. (What they call a hidden layer, Flux calls the output of the first layer, `model[1](noisy)`.) Since then things have developed a little.
 
-## Features of Note
+## Features to Note
 
 Some things to notice in this example are:
 
-* The batch dimension of data is always the last one. Thus a `2×1000 Matrix` is a thousand observations, each a column of length 2.
-
-* The `model` can be called like a function, `y = model(x)`. It encapsulates the parameters (and state).
+* The batch dimension of data is always the last one. Thus a `2×1000 Matrix` is a thousand observations, each a column of length 2. Flux defaults to `Float32`, but most of Julia to `Float64`.
 
-* But the model does not contain the loss function, nor the optimisation rule. Instead the [`Adam()`](@ref Flux.Adam) object stores between iterations the momenta it needs.
+* The `model` can be called like a function, `y = model(x)`. Each layer like [`Dense`](@ref Flux.Dense) is an ordinary `struct`, which encapsulates some arrays of parameters (and possibly other state, as for [`BatchNorm`](@ref Flux.BatchNorm)).
 
-* The function [`train!`](@ref Flux.train!) likes data as an iterator generating `Tuple`s, here produced by [`DataLoader`](@ref). This mutates both the `model` and the optimiser state inside `opt`.
+* But the model does not contain the loss function, nor the optimisation rule. The [`Adam`](@ref Flux.Adam) object stores between iterations the momenta it needs. And [`Flux.crossentropy`](@ref Flux.Losses.crossentropy) is an ordinary function.
 
-There are other ways to train Flux models, for more control than `train!` provides:
+* The `do` block creates an anonymous function, as the first argument of `gradient`. Anything executed within this is differentiated.
 
-* Within Flux, you can easily write a training loop, calling [`gradient`](@ref) and [`update!`](@ref Flux.update!).
+Instead of calling [`gradient`](@ref Zygote.gradient) and [`update!`](@ref Flux.update!) separately, there is a convenience function [`train!`](@ref Flux.train!). If we didn't want anything extra (like logging the loss), we could replace the training loop with the following:
 
-* For a lower-level way, see the package [Optimisers.jl](https://github.com/FluxML/Optimisers.jl).
-
-* For higher-level ways, see [FluxTraining.jl](https://github.com/FluxML/FluxTraining.jl) and [FastAI.jl](https://github.com/FluxML/FastAI.jl).
+```julia
+for epoch in 1:1_000
+    train!(pars, loader, opt) do x, y
+        y_hat = model(x)
+        Flux.crossentropy(y_hat, y)
+    end
+end
+```
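A final note on the one-hot encoding used in the quickstart: `Flux.onecold` is the inverse of `Flux.onehotbatch`, so the trained network's output can be decoded straight back into `Bool` labels. A small sketch (not part of the patch), assuming the `model`, `noisy` and `truth` defined on the page above; it gives the same accuracy as the `out2[1,:] .> 0.5` comparison already shown:

```julia
using Flux, Statistics

out2   = model(noisy |> gpu) |> cpu          # 2×1000 Matrix{Float32}; softmax makes each column sum to 1
labels = Flux.onecold(out2, [true, false])   # picks the label of the largest entry in each column
mean(labels .== truth)                       # fraction of correct predictions
```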