## Classification of the MNIST dataset with a BatchEnsemble variant of the
## LeNet5 convolutional neural network. This script also combines various
## packages from the Julia ecosystem with Flux.
using Flux
using Flux.Data: DataLoader
using Flux.Optimise: Optimiser, WeightDecay
using Flux: onehotbatch, onecold, glorot_normal, label_smoothing
using Flux.Losses: logitcrossentropy, crossentropy
using Statistics, Random
using Logging: with_logger
using TensorBoardLogger: TBLogger, tb_overwrite, set_step!, set_step_increment!
using ProgressMeter: @showprogress
import MLDatasets
import BSON
using CUDA
using Formatting

using DeepUncertainty
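
# DeepUncertainty.jl provides the BatchEnsemble layers (ConvBatchEnsemble,
# DenseBatchEnsemble) and the ExpectedCalibrationError metric used below.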

# LeNet5 "constructor".
# The model can be adapted to any image size
# and any number of output classes.
function LeNet5(args; imgsize = (28, 28, 1), nclasses = 10)
    out_conv_size = (imgsize[1] ÷ 4 - 3, imgsize[2] ÷ 4 - 3, 16)

    return Chain(
        ConvBatchEnsemble((5, 5), imgsize[end] => 6, args.rank, args.ensemble_size, relu),
        MaxPool((2, 2)),
        ConvBatchEnsemble((5, 5), 6 => 16, args.rank, args.ensemble_size, relu),
        MaxPool((2, 2)),
        flatten,
        DenseBatchEnsemble(prod(out_conv_size), 120, args.rank, args.ensemble_size, relu),
        DenseBatchEnsemble(120, 84, args.rank, args.ensemble_size, relu),
        DenseBatchEnsemble(84, nclasses, args.rank, args.ensemble_size),
    )
end
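
# Worked check of `out_conv_size` for the default 28×28 input: a 5×5 valid
# convolution maps 28 → 24, a 2×2 max-pool maps 24 → 12, the second conv
# maps 12 → 8, and the second pool maps 8 → 4. So out_conv_size = (4, 4, 16),
# the first DenseBatchEnsemble receives prod((4, 4, 16)) = 256 features, and
# the shortcut 28 ÷ 4 - 3 = 4 used above checks out.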

function get_data(args)
    # Load the MNIST train and test splits as Float32 arrays
    xtrain, ytrain = MLDatasets.MNIST.traindata(Float32)
    xtest, ytest = MLDatasets.MNIST.testdata(Float32)

    # Add a singleton channel dimension: (28, 28, N) -> (28, 28, 1, N)
    xtrain = reshape(xtrain, 28, 28, 1, :)
    xtest = reshape(xtest, 28, 28, 1, :)

    # One-hot encode the digit labels 0:9
    ytrain, ytest = onehotbatch(ytrain, 0:9), onehotbatch(ytest, 0:9)

    train_loader = DataLoader(
        (xtrain, ytrain),
        batchsize = args.batchsize,
        shuffle = true,
        partial = false,
    )
    test_loader = DataLoader((xtest, ytest), batchsize = args.batchsize, partial = false)

    return train_loader, test_loader
end
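
# Note that `partial = false` drops the final incomplete batch from both
# loaders, so every batch holds exactly `args.batchsize` observations; the
# fixed-size reshape in `eval_loss_accuracy` below relies on this.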

# `logitcrossentropy` expects raw logits, so `loss` should be applied to
# model outputs before any softmax.
loss(ŷ, y) = logitcrossentropy(ŷ, y)

# Number of correct predictions in a batch (a raw count; it is normalized by
# the total observation count in `eval_loss_accuracy`).
function accuracy(preds, labels)
    acc = sum(onecold(preds |> cpu) .== onecold(labels |> cpu))
    return acc
end
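
# A tiny worked example of the `onecold` comparison above (hypothetical
# values): with preds = [0.1 0.8; 0.9 0.2] and one-hot labels for classes
# (2, 1), `onecold` returns [2, 1] for both arguments, so `accuracy` counts
# 2 correct predictions.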

function eval_loss_accuracy(args, loader, model, device)
    l = [0.0f0 for x = 1:args.ensemble_size]
    acc = [0 for x = 1:args.ensemble_size]
    ece_list = [0.0f0 for x = 1:args.ensemble_size]
    ntot = 0
    mean_l = 0.0f0
    mean_acc = 0.0f0
    mean_ece = 0.0f0
    for (x, y) in loader
        # Tile the batch so every ensemble member sees the same inputs
        x = repeat(x, 1, 1, 1, args.ensemble_size)
        x, y = x |> device, y |> device
        # Forward pass: raw logits, then softmax probabilities
        ŷ = model(x)
        probs = softmax(ŷ, dims = 1)
        # Reshape the predictions into [classes, batch_size, ensemble_size]
        reshaped_logits = reshape(ŷ, size(ŷ, 1), args.batchsize, args.ensemble_size)
        reshaped_probs = reshape(probs, size(probs, 1), args.batchsize, args.ensemble_size)
        # Loop through each member's predictions
        for ensemble = 1:args.ensemble_size
            member_logits = reshaped_logits[:, :, ensemble]
            member_probs = reshaped_probs[:, :, ensemble]
            # Per-member loss is computed on the logits, since `loss`
            # is `logitcrossentropy`
            l[ensemble] += loss(member_logits, y) * size(member_logits)[end]
            acc[ensemble] += accuracy(member_probs, y)
            ece_list[ensemble] +=
                ExpectedCalibrationError(member_probs |> cpu, onecold(y |> cpu)) *
                args.batchsize
        end
        # The ensemble prediction is the mean of the member probabilities
        mean_probs = dropdims(mean(reshaped_probs, dims = 3), dims = 3)
        mean_l += crossentropy(mean_probs, y) * size(mean_probs)[end]
        mean_acc += accuracy(mean_probs, y)
        mean_ece +=
            ExpectedCalibrationError(mean_probs |> cpu, onecold(y |> cpu)) *
            args.batchsize
        ntot += size(mean_probs)[end]
    end
    # Normalize the running sums by the total number of observations
    losses = [x / ntot |> round4 for x in l]
    acc = [a / ntot * 100 |> round4 for a in acc]
    ece_list = [x / ntot |> round4 for x in ece_list]
    mean_l = mean_l / ntot |> round4
    mean_acc = mean_acc / ntot * 100 |> round4
    mean_ece = mean_ece / ntot |> round4

    # Print each ensemble member's loss, accuracy, and ECE
    for ensemble = 1:args.ensemble_size
        @info (format(
            "Model {} Loss: {} Accuracy: {} ECE: {}",
            ensemble,
            losses[ensemble],
            acc[ensemble],
            ece_list[ensemble],
        ))
    end
    @info (format(
        "Mean Loss: {} Mean Accuracy: {} Mean ECE: {}",
        mean_l,
        mean_acc,
        mean_ece,
    ))
    @info "==========================================================="
    return nothing
end
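
# A minimal sketch of the reshape above, assuming 2 classes, a batch of 3,
# and 2 ensemble members: tiling makes the model return a (2, 6) matrix whose
# columns 1:3 come from member 1 and columns 4:6 from member 2, so
# reshape(ŷ, 2, 3, 2)[:, :, m] recovers member m's predictions for the batch.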

## utility functions
num_params(model) = sum(length, Flux.params(model))
round4(x) = round(x, digits = 4)

# arguments for the `train` function
Base.@kwdef mutable struct Args
    η = 3e-4             # learning rate
    λ = 0                # L2 regularizer param, implemented as weight decay
    batchsize = 32       # batch size
    epochs = 10          # number of epochs
    seed = 0             # set seed > 0 for reproducibility
    use_cuda = true      # if true use cuda (if available)
    infotime = 1         # report every `infotime` epochs
    checktime = 5        # Save the model every `checktime` epochs. Set to 0 for no checkpoints.
    savepath = "runs/"   # results path
    rank = 1             # rank of the BatchEnsemble fast-weight factors
    ensemble_size = 4    # number of ensemble members
end
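
# Thanks to `Base.@kwdef`, any field can be overridden per run through the
# keyword arguments of `train`, e.g. (hypothetical values):
#   train(epochs = 2, ensemble_size = 2, use_cuda = false)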

function train(; kws...)
    args = Args(; kws...)
    args.seed > 0 && Random.seed!(args.seed)
    use_cuda = args.use_cuda && CUDA.functional()

    if use_cuda
        device = gpu
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    ## DATA
    train_loader, test_loader = get_data(args)
    @info "Dataset MNIST: $(train_loader.nobs) train and $(test_loader.nobs) test examples"

    ## MODEL AND OPTIMIZER
    model = LeNet5(args) |> device
    @info "LeNet5 model: $(num_params(model)) trainable params"

    ps = Flux.params(model)

    opt = ADAM(args.η)
    if args.λ > 0 # add weight decay, equivalent to L2 regularization
        opt = Optimiser(WeightDecay(args.λ), opt)
    end

    function report(epoch)
        # @info "Train Metrics"
        # eval_loss_accuracy(args, train_loader, model, device)
        @info "Test metrics"
        eval_loss_accuracy(args, test_loader, model, device)
    end

    ## TRAINING
    @info "Start Training"
    report(0)
    for epoch = 1:args.epochs
        @showprogress for (x, y) in train_loader
            # Tile inputs and labels so each ensemble member trains on a
            # full copy of the batch
            x = repeat(x, 1, 1, 1, args.ensemble_size)
            y = repeat(y, 1, args.ensemble_size)
            x, y = x |> device, y |> device
            gs = Flux.gradient(ps) do
                ŷ = model(x)
                loss(ŷ, y)
            end

            Flux.Optimise.update!(opt, ps, gs)
        end

        ## Printing and logging
        epoch % args.infotime == 0 && report(epoch)
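
        # Checkpointing sketch using the imported-but-unused BSON and the
        # `checktime`/`savepath` fields, following the common Flux model-zoo
        # pattern (the "model.bson" filename is an assumption):
        if args.checktime > 0 && epoch % args.checktime == 0
            !ispath(args.savepath) && mkpath(args.savepath)
            modelpath = joinpath(args.savepath, "model.bson")
            let model = cpu(model)
                BSON.@save modelpath model epoch
            end
            @info "Model saved in \"$(modelpath)\""
        end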
    end
end

train()