
Commit ed3d16c

Added Dense and Conv BatchEnsemble layers along with unit tests and example on MNIST classification using LeNet5 (#4)
* Added Dense and Conv BatchEnsemble layers along with unit tests and an example of MNIST classification using LeNet5
* Merged the dense BatchEnsemble forward passes for rank > 1 and rank = 1; fixed the conv BatchEnsemble unit test
* Changes:
  1. Reduced imports and moved them to the main file
  2. Renamed the test files
  3. Added GPU tests for the layers -- basic forward passes etc. for now
1 parent b11fbea commit ed3d16c

File tree

11 files changed: +728 -18 lines changed


Project.toml

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
  Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
  ReliabilityDiagrams = "e5f51471-6270-49e4-a15a-f1cfbff4f856"
  Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

  [compat]
  julia = "1"

examples/batchensemble.jl

Lines changed: 207 additions & 0 deletions
@@ -0,0 +1,207 @@
## Classification of MNIST dataset
## with the convolutional neural network known as LeNet5.
## This script also combines various
## packages from the Julia ecosystem with Flux.
using Flux
using Flux.Data: DataLoader
using Flux.Optimise: Optimiser, WeightDecay
using Flux: onehotbatch, onecold, glorot_normal, label_smoothing
using Flux.Losses: logitcrossentropy
using Statistics, Random
using Logging: with_logger
using TensorBoardLogger: TBLogger, tb_overwrite, set_step!, set_step_increment!
using ProgressMeter: @showprogress
import MLDatasets
import BSON
using CUDA
using Formatting

using DeepUncertainty

# LeNet5 "constructor".
# The model can be adapted to any image size
# and any number of output classes.
function LeNet5(args; imgsize = (28, 28, 1), nclasses = 10)
    out_conv_size = (imgsize[1] ÷ 4 - 3, imgsize[2] ÷ 4 - 3, 16)

    return Chain(
        ConvBatchEnsemble((5, 5), imgsize[end] => 6, args.rank, args.ensemble_size, relu),
        MaxPool((2, 2)),
        ConvBatchEnsemble((5, 5), 6 => 16, args.rank, args.ensemble_size, relu),
        MaxPool((2, 2)),
        flatten,
        DenseBatchEnsemble(prod(out_conv_size), 120, args.rank, args.ensemble_size, relu),
        DenseBatchEnsemble(120, 84, args.rank, args.ensemble_size, relu),
        DenseBatchEnsemble(84, nclasses, args.rank, args.ensemble_size),
    )
end

function get_data(args)
    xtrain, ytrain = MLDatasets.MNIST.traindata(Float32)
    xtest, ytest = MLDatasets.MNIST.testdata(Float32)

    xtrain = reshape(xtrain, 28, 28, 1, :)
    xtest = reshape(xtest, 28, 28, 1, :)

    ytrain, ytest = onehotbatch(ytrain, 0:9), onehotbatch(ytest, 0:9)

    train_loader = DataLoader(
        (xtrain, ytrain),
        batchsize = args.batchsize,
        shuffle = true,
        partial = false,
    )
    test_loader = DataLoader((xtest, ytest), batchsize = args.batchsize, partial = false)

    return train_loader, test_loader
end

loss(ŷ, y) = logitcrossentropy(ŷ, y)

function accuracy(preds, labels)
    acc = sum(onecold(preds |> cpu) .== onecold(labels |> cpu))
    return acc
end

function eval_loss_accuracy(args, loader, model, device)
    l = [0.0f0 for x = 1:args.ensemble_size]
    acc = [0 for x = 1:args.ensemble_size]
    ece_list = [0.0f0 for x = 1:args.ensemble_size]
    ntot = 0
    mean_l = 0
    mean_acc = 0
    mean_ece = 0
    for (x, y) in loader
        x = repeat(x, 1, 1, 1, args.ensemble_size)
        x, y = x |> device, y |> device
        # Perform the forward pass
        ŷ = model(x)
        ŷ = softmax(ŷ, dims = 1)
        # Reshape the predictions into [classes, batch_size, ensemble_size]
        reshaped_ŷ = reshape(ŷ, size(ŷ)[1], args.batchsize, args.ensemble_size)
        # Loop through each model's predictions
        for ensemble = 1:args.ensemble_size
            model_predictions = reshaped_ŷ[:, :, ensemble]
            # Calculate individual loss
            l[ensemble] += loss(model_predictions, y) * size(model_predictions)[end]
            acc[ensemble] += accuracy(model_predictions, y)
            ece_list[ensemble] +=
                ExpectedCalibrationError(model_predictions |> cpu, onecold(y |> cpu)) *
                args.batchsize
        end
        # Get the mean predictions
        mean_predictions = mean(reshaped_ŷ, dims = ndims(reshaped_ŷ))
        mean_predictions = dropdims(mean_predictions, dims = ndims(mean_predictions))
        mean_l += loss(mean_predictions, y) * size(mean_predictions)[end]
        mean_acc += accuracy(mean_predictions, y)
        mean_ece +=
            ExpectedCalibrationError(mean_predictions |> cpu, onecold(y |> cpu)) *
            args.batchsize
        ntot += size(mean_predictions)[end]
    end
    # Normalize the loss
    losses = [loss / ntot |> round4 for loss in l]
    acc = [a / ntot * 100 |> round4 for a in acc]
    ece_list = [x / ntot |> round4 for x in ece_list]
    # Calculate mean loss
    mean_l = mean_l / ntot |> round4
    mean_acc = mean_acc / ntot * 100 |> round4
    mean_ece = mean_ece / ntot |> round4

    # Print the per-ensemble-member loss and accuracy
    for ensemble = 1:args.ensemble_size
        @info (format(
            "Model {} Loss: {} Accuracy: {} ECE: {}",
            ensemble,
            losses[ensemble],
            acc[ensemble],
            ece_list[ensemble],
        ))
    end
    @info (format(
        "Mean Loss: {} Mean Accuracy: {} Mean ECE: {}",
        mean_l,
        mean_acc,
        mean_ece,
    ))
    @info "==========================================================="
    return nothing
end

## utility functions
num_params(model) = sum(length, Flux.params(model))
round4(x) = round(x, digits = 4)

# arguments for the `train` function
Base.@kwdef mutable struct Args
    η = 3e-4           # learning rate
    λ = 0              # L2 regularizer param, implemented as weight decay
    batchsize = 32     # batch size
    epochs = 10        # number of epochs
    seed = 0           # set seed > 0 for reproducibility
    use_cuda = true    # if true use cuda (if available)
    infotime = 1       # report every `infotime` epochs
    checktime = 5      # Save the model every `checktime` epochs. Set to 0 for no checkpoints.
    savepath = "runs/" # results path
    rank = 1
    ensemble_size = 4
end

function train(; kws...)
    args = Args(; kws...)
    args.seed > 0 && Random.seed!(args.seed)
    use_cuda = args.use_cuda && CUDA.functional()

    if use_cuda
        device = gpu
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    ## DATA
    train_loader, test_loader = get_data(args)
    @info "Dataset MNIST: $(train_loader.nobs) train and $(test_loader.nobs) test examples"

    ## MODEL AND OPTIMIZER
    model = LeNet5(args) |> device
    @info "LeNet5 model: $(num_params(model)) trainable params"

    ps = Flux.params(model)

    opt = ADAM(args.η)
    if args.λ > 0 # add weight decay, equivalent to L2 regularization
        opt = Optimiser(WeightDecay(args.λ), opt)
    end

    function report(epoch)
        # @info "Train Metrics"
        # eval_loss_accuracy(args, train_loader, model, device)
        @info "Test metrics"
        eval_loss_accuracy(args, test_loader, model, device)
    end

    ## TRAINING
    @info "Start Training"
    report(0)
    for epoch = 1:args.epochs
        @showprogress for (x, y) in train_loader
            # Make copies of batches for ensembles
            x = repeat(x, 1, 1, 1, args.ensemble_size)
            y = repeat(y, 1, args.ensemble_size)
            x, y = x |> device, y |> device
            gs = Flux.gradient(ps) do
                ŷ = model(x)
                loss(ŷ, y)
            end

            Flux.Optimise.update!(opt, ps, gs)
        end

        ## Printing and logging
        epoch % args.infotime == 0 && report(epoch)
    end
end

train()
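The repeat/reshape bookkeeping in `eval_loss_accuracy` is the part of the example that recovers per-member predictions, so here is a minimal, self-contained sketch of just that step with concrete shapes. The array sizes and the `fake_probs` stand-in are illustrative only and are not part of the package API.

using Statistics

batchsize, nclasses, ensemble_size = 10, 3, 4

# A mini batch of 10 MNIST-sized images; repeating along the batch
# dimension gives 40 samples, one copy per ensemble member.
x = rand(Float32, 28, 28, 1, batchsize)
x_rep = repeat(x, 1, 1, 1, ensemble_size)    # size: (28, 28, 1, 40)

# Stand-in for `softmax(model(x_rep), dims = 1)`: class probabilities
# for all 40 repeated samples.
fake_probs = rand(Float32, nclasses, batchsize * ensemble_size)

# Member k's predictions for the original 10 samples are the slice [:, :, k].
reshaped = reshape(fake_probs, nclasses, batchsize, ensemble_size)

# The ensemble prediction is the mean over members, back to (nclasses, batchsize).
mean_preds = dropdims(mean(reshaped, dims = 3), dims = 3)
@assert size(mean_preds) == (nclasses, batchsize)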

src/DeepUncertainty.jl

Lines changed: 7 additions & 0 deletions
@@ -1,10 +1,17 @@
  module DeepUncertainty

+ using Flux
+ using Random
+ using Flux: @functor, glorot_normal, create_bias
+
  # Export layers
  export MCLayer, MCDense, MCConv
+ export DenseBatchEnsemble, ConvBatchEnsemble
  export mean_loglikelihood, brier_score, ExpectedCalibrationError, prediction_metrics

  include("metrics.jl")
  include("layers/mclayers.jl")
+ include("layers/BatchEnsemble/dense.jl")
+ include("layers/BatchEnsemble/conv.jl")

  end
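With these exports in place, the new layers can be constructed directly from DeepUncertainty. Below is a minimal usage sketch based only on the constructor calls that appear in examples/batchensemble.jl and conv.jl; the input sizes are made up for illustration, and the exact output shapes are an assumption, not documented behavior.

using Flux
using DeepUncertainty

rank, ensemble_size = 1, 4

# Same argument order as in the LeNet5 example:
# sizes first, then rank, ensemble_size, and the activation.
conv = ConvBatchEnsemble((5, 5), 1 => 6, rank, ensemble_size, relu)
dense = DenseBatchEnsemble(10, 5, rank, ensemble_size, relu)

# Inputs are repeated along the batch dimension, one copy per member,
# so 8 base samples become 8 * 4 = 32 samples.
x = repeat(rand(Float32, 28, 28, 1, 8), 1, 1, 1, ensemble_size)
y = conv(x)                                  # batch dimension stays 32

z = repeat(rand(Float32, 10, 8), 1, ensemble_size)
out = dense(z)                               # assumed size: (5, 32)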

src/layers/BatchEnsemble/conv.jl

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
"""
    ConvBatchEnsemble(filter, in => out, rank,
                      ensemble_size, σ = identity;
                      stride = 1, pad = 0, dilation = 1,
                      groups = 1, [bias, weight, init])
    ConvBatchEnsemble(layer, alpha, gamma, ensemble_bias, ensemble_act, rank)

Creates a conv BatchEnsemble layer. BatchEnsemble is a memory-efficient alternative
to deep ensembles. In deep ensembles, if the ensemble size is N, N different models
are trained, making the time and memory complexity O(N * complexity of one network).
BatchEnsemble generates the weight matrix for each ensemble member from a pair of
rank-1 vectors R (alpha) and S (gamma), by multiplying the shared weight matrix W
element-wise with RS'. We also refer to R and S as fast weights.

Reference - https://arxiv.org/abs/2002.06715

During both training and testing, we repeat the samples along the batch dimension
N times, where N is the ensemble_size. For example, if each mini batch has 10 samples
and our ensemble size is 4, then the actual input to the layer has 40 samples.
The output of the layer has 40 samples as well, and each block of 10 samples can be
considered the output of one ensemble member.

# Fields
- `layer`: The conv layer which transforms the perturbed input to output
- `alpha`: The first fast weight of size (in_dim, ensemble_size)
- `gamma`: The second fast weight of size (out_dim, ensemble_size)
- `ensemble_bias`: Bias added to the ensemble output, separate from the conv layer bias
- `ensemble_act`: The activation function applied to the ensemble output
- `rank`: Rank of the fast weights (rank > 1 doesn't work on GPU for now)

# Arguments
- `filter::NTuple{N,Integer}`: Kernel dimensions, e.g., (5, 5)
- `ch::Pair{<:Integer,<:Integer}`: Input channels => output channels
- `rank::Integer`: Rank of the fast weights
- `ensemble_size::Integer`: Number of models in the ensemble
- `σ::F=identity`: Activation of the conv layer, defaults to identity
- `init=glorot_normal`: Initialization function, defaults to glorot_normal
- `alpha_init=glorot_normal`: Initialization function for the alpha fast weight,
    defaults to glorot_normal
- `gamma_init=glorot_normal`: Initialization function for the gamma fast weight,
    defaults to glorot_normal
- `bias::Bool=true`: Toggle the usage of bias in the conv layer
- `ensemble_bias::Bool=true`: Toggle the usage of ensemble bias
- `ensemble_act::F=identity`: Activation function for ensemble outputs
"""
struct ConvBatchEnsemble{L,F,M,B}
    layer::L
    alpha::M
    gamma::M
    ensemble_bias::B
    ensemble_act::F
    rank::Any
    function ConvBatchEnsemble(
        layer::L,
        alpha::M,
        gamma::M,
        ensemble_bias = true,
        ensemble_act::F = identity,
        rank = 1,
    ) where {M,F,L}
        ensemble_bias = create_bias(gamma, ensemble_bias, size(gamma)[1], size(gamma)[2])
        new{typeof(layer),F,M,typeof(ensemble_bias)}(
            layer,
            alpha,
            gamma,
            ensemble_bias,
            ensemble_act,
            rank,
        )
    end
end

function ConvBatchEnsemble(
    k::NTuple{N,Integer},
    ch::Pair{<:Integer,<:Integer},
    rank::Integer,
    ensemble_size::Integer,
    σ = identity;
    init = glorot_normal,
    alpha_init = glorot_normal,
    gamma_init = glorot_normal,
    stride = 1,
    pad = 0,
    dilation = 1,
    groups = 1,
    bias = true,
    ensemble_bias = true,
    ensemble_act = identity,
) where {N}
    layer = Flux.Conv(
        k,
        ch,
        σ;
        stride = stride,
        pad = pad,
        dilation = dilation,
        init = init,
        groups = groups,
        bias = bias,
    )
    in_dim = ch[1]
    out_dim = ch[2]
    if rank >= 1
        alpha_shape = (in_dim, ensemble_size)
        gamma_shape = (out_dim, ensemble_size)
    else
        error("Rank must be >= 1.")
    end
    alpha = alpha_init(alpha_shape)
    gamma = gamma_init(gamma_shape)

    return ConvBatchEnsemble(layer, alpha, gamma, ensemble_bias, ensemble_act, rank)
end

@functor ConvBatchEnsemble

function (be::ConvBatchEnsemble)(x)
    # Conv Batch Ensemble params
    layer = be.layer
    alpha = be.alpha
    gamma = be.gamma
    e_b = be.ensemble_bias
    e_σ = be.ensemble_act

    batch_size = size(x)[end]
    in_size = size(alpha)[1]
    out_size = size(gamma)[1]
    ensemble_size = size(alpha)[2]
    samples_per_model = batch_size ÷ ensemble_size

    # Alpha, gamma shapes - [units, ensembles, rank]
    e_b = repeat(e_b, samples_per_model)
    alpha = repeat(alpha, samples_per_model)
    gamma = repeat(gamma, samples_per_model)
    # Reshape alpha, gamma to [units, batch_size, rank]
    e_b = reshape(e_b, (1, 1, out_size, batch_size))
    alpha = reshape(alpha, (1, 1, in_size, batch_size))
    gamma = reshape(gamma, (1, 1, out_size, batch_size))

    perturbed_x = x .* alpha
    output = layer(perturbed_x) .* gamma
    output = e_σ.(output .+ e_b)

    return output
end
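The forward pass above never materializes a per-member weight tensor: scaling the input by alpha and the output by gamma gives the same result as applying a weight that has been perturbed element-wise by the rank-1 outer product of the fast weights. A quick dense-layer check of that identity, as a sketch for intuition only (the sizes and vectors here are made up):

in_dim, out_dim = 4, 3
W = randn(out_dim, in_dim)      # shared slow weight
alpha = randn(in_dim)           # fast weight for one member (input side)
gamma = randn(out_dim)          # fast weight for one member (output side)
x = randn(in_dim)

# Member weight: W perturbed element-wise by the rank-1 outer product
W_member = W .* (gamma * alpha')

# BatchEnsemble forward pass: scale the input by alpha, the output by gamma
y_fast = gamma .* (W * (alpha .* x))

@assert W_member * x ≈ y_fast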
