
Commit 3469eb6

Merge pull request #158 from Dsantra92/fix/dcgan
Fix the links for https://fluxml.ai/tutorialposts/2021-10-08-dcgan-mnist/
2 parents d6cbc78 + c33713d · commit 3469eb6

File tree: 1 file changed (+17 -23 lines)


tutorialposts/2021-10-08-dcgan-mnist.md

+17 -23
@@ -11,7 +11,7 @@ This is a beginner level tutorial for generating images of handwritten digits us
 
 A GAN is composed of two sub-models - the **generator** and the **discriminator** acting against one another. The generator can be considered as an artist who draws (generates) new images that look real, whereas the discriminator is a critic who learns to tell real images apart from fakes.
 
-![](../../assets/tutorialposts/2021-10-8-dcgan-mnist/cat_gan.png)
+![](../../assets/tutorialposts/2021-10-08-dcgan-mnist/cat_gan.png)
 
 The GAN starts with a generator and discriminator which have very little or no idea about the underlying data. During training, the generator progressively becomes better at creating images that look real, while the discriminator becomes better at telling them apart. The process reaches equilibrium when the discriminator can no longer distinguish real images from fakes.
 
@@ -24,9 +24,9 @@ This tutorial demonstrates the process of training a DC-GAN on the [MNIST datase
 
 ~~~
 <br><br>
-<p align="center">
-<img src="../../assets/tutorialposts/2021-10-8-dcgan-mnist/output.gif" align="middle" width="200">
-</p>
+<div style="text-align:center">
+<img src="../../assets/tutorialposts/2021-10-08-dcgan-mnist/output.gif" width="200">
+</div>
 ~~~
 
 ## Setup
@@ -43,7 +43,7 @@ Pkg.add(["Images", "Flux", "MLDatasets", "CUDA", "Parameters"])
 ```
 *Note: Depending on your internet speed, it may take a few minutes for the packages to install.*
 
-<br>
+\
 After installing the libraries, load the required packages and functions:
 ```julia
 using Base.Iterators: partition
@@ -59,7 +59,7 @@ using Flux.Losses: logitbinarycrossentropy
 using MLDatasets: MNIST
 using CUDA
 ```
-<br>
+\
 Now we set default values for the learning rates, batch size, epochs, the usage of a GPU (if available) and other hyperparameters for our model.
 
 ```julia
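The hyperparameter block this hunk leads into is not shown in the diff. Given the `Parameters` package installed during setup, it plausibly looks like the following sketch; the type name, fields, and defaults here are assumptions, not the tutorial's code:

```julia
using Parameters: @with_kw

# Hypothetical hyperparameter container built with Parameters.jl;
# the tutorial's actual fields and defaults may differ.
@with_kw struct HyperParams
    batch_size::Int = 128
    latent_dim::Int = 100
    epochs::Int = 25
    lr_dscr::Float32 = 0.0002f0  # discriminator learning rate
    lr_gen::Float32 = 0.0002f0   # generator learning rate
end
```

With `@with_kw`, a call like `HyperParams(batch_size = 64)` overrides one field while keeping the other defaults.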
@@ -116,7 +116,6 @@ We will also apply the weight initialization method mentioned in the original DC
 # sampled from a Gaussian distribution with μ=0 and σ=0.02
 dcgan_init(shape...) = randn(Float32, shape) * 0.02f0
 ```
-<br>
 
 ```julia
 function Generator(latent_dim)
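For context on `dcgan_init`: Flux layers take a custom weight initializer through their `init` keyword, so the generator presumably wires it in along these lines. The layer sizes below are illustrative only, not taken from the tutorial:

```julia
using Flux

# `dcgan_init` as defined in the hunk above.
dcgan_init(shape...) = randn(Float32, shape) * 0.02f0

# Illustrative layer: the `init` keyword makes Flux draw the initial
# weights from a Gaussian with μ = 0 and σ = 0.02.
ConvTranspose((4, 4), 256 => 128; stride = 2, pad = 1, init = dcgan_init)
```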
@@ -137,7 +136,7 @@ function Generator(latent_dim)
     )
 end
 ```
-<br>
+\
 Time for a small test!! We create a dummy generator and feed a random vector as a seed to the generator. If our generator is initialized correctly it will return an array of size (28, 28, 1, `batch_size`). The `@assert` macro in Julia will raise an exception for the wrong output size.
 
 ```julia
@@ -150,10 +149,7 @@ gen_image = generator(noise)
 @assert size(gen_image) == (28, 28, 1, 3)
 ```
 
-<br>
 Our generator model is yet to learn the correct weights, so it does not produce a recognizable image for now. To train our poor generator we need its equal rival, the *discriminator*.
-<br>
-<br>
 
 ### Discriminator
 
@@ -187,15 +183,14 @@ discriminator = Discriminator()
 logits = discriminator(gen_image)
 @assert size(logits) == (1, 3)
 ```
-<br>
+
 Just like our dummy generator, the untrained discriminator has no idea about what is a real or fake image. It needs to be trained alongside the generator to output positive values for real images, and negative values for fake images.
 
 ## Loss functions for GAN
 
 In a GAN problem, there are only two labels involved: fake and real. So Binary CrossEntropy is an easy choice for a preliminary loss function.
 
 But even if Flux's `binarycrossentropy` does the job for us, due to numerical stability it is always preferred to compute cross-entropy using logits. Flux provides [logitbinarycrossentropy](https://fluxml.ai/Flux.jl/stable/models/losses/#Flux.Losses.logitbinarycrossentropy) specifically for this purpose. Mathematically it is equivalent to `binarycrossentropy(σ(ŷ), y, kwargs...)`.
-<br>
 
 ### Discriminator Loss
 
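The equivalence claimed in that paragraph is easy to check numerically; a small sketch, not part of the tutorial itself:

```julia
using Flux: sigmoid
using Flux.Losses: binarycrossentropy, logitbinarycrossentropy

ŷ = randn(Float32, 1, 3)  # raw logits, e.g. straight from the discriminator
y = [1 0 1]               # ground-truth labels

# Same value, but the logit form never materializes σ(ŷ), which is
# what makes it numerically stabler for large-magnitude logits.
@assert logitbinarycrossentropy(ŷ, y) ≈ binarycrossentropy(sigmoid.(ŷ), y)
```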
@@ -213,15 +208,14 @@ function discriminator_loss(real_output, fake_output)
 return real_loss + fake_loss
 end
 ```
-<br>
 ### Generator Loss
 
 The generator's loss quantifies how well it was able to trick the discriminator. Intuitively, if the generator is performing well, the discriminator will classify the fake images as real (or 1).
 
 ```julia
 generator_loss(fake_output) = logitbinarycrossentropy(fake_output, 1)
 ```
-<br>
+\
 We also need optimizers for our network. Why, you may ask? Read more [here](https://towardsdatascience.com/overview-of-various-optimizers-in-neural-networks-17c1be2df6d5). For both the generator and discriminator, we will use the [ADAM optimizer](https://fluxml.ai/Flux.jl/stable/training/optimisers/#Flux.Optimise.ADAM).
 
 ## Utility functions
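Only the tail of `discriminator_loss` survives in the hunk above; presumably the elided lines label real outputs as 1 and fake outputs as 0, mirroring `generator_loss`. A sketch, plus the two ADAM optimizers the preceding paragraph calls for:

```julia
using Flux: ADAM
using Flux.Losses: logitbinarycrossentropy

# Sketch of the full loss: real images are labelled 1, fakes 0.
function discriminator_loss(real_output, fake_output)
    real_loss = logitbinarycrossentropy(real_output, 1)
    fake_loss = logitbinarycrossentropy(fake_output, 0)
    return real_loss + fake_loss
end

# One ADAM instance per network; 0.0002 is the learning rate the DCGAN
# paper suggests, though the tutorial's value may differ.
opt_dscr = ADAM(0.0002)
opt_gen  = ADAM(0.0002)
```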
@@ -254,7 +248,7 @@ function train_discriminator!(gen, disc, real_img, fake_img, opt, ps, hparams)
 return disc_loss
 end
 ```
-<br>
+\
 We define a similar function for the generator.
 
 ```julia
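Most of `train_discriminator!` is elided here. Given its signature in the hunk header, a plausible body takes a Zygote pullback over the discriminator's parameters and applies one optimizer step; a sketch under those assumptions, not the tutorial's exact code:

```julia
using Flux

function train_discriminator!(gen, disc, real_img, fake_img, opt, ps, hparams)
    # Differentiate the loss w.r.t. the discriminator parameters `ps`
    # (implicit-params style, with ps = Flux.params(disc) assumed).
    disc_loss, back = Flux.pullback(ps) do
        discriminator_loss(disc(real_img), disc(fake_img))
    end
    grads = back(1f0)             # seed the reverse pass with 1
    Flux.update!(opt, ps, grads)  # one optimizer step
    return disc_loss
end
```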
@@ -268,8 +262,7 @@ function train_generator!(gen, disc, fake_img, opt, ps, hparams)
 return gen_loss
 end
 ```
-<br>
-
+\
 Now that we have defined every function we need, we integrate everything into a single `train` function where we first set up all the models and optimizers and then train the GAN for a specified number of epochs.
 
 ```julia
@@ -337,7 +330,7 @@ function train(hparams)
 return nothing
 end
 ```
-<br>
+
 Now we finally get to train the GAN:
 
 ```julia
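The call inside that block is elided by the diff; it presumably just builds the hyperparameters and runs training, roughly like this (names follow the earlier sketches):

```julia
# Hypothetical invocation; the tutorial's actual arguments may differ.
hparams = HyperParams()
train(hparams)
```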
@@ -359,10 +352,11 @@ images = load.(img_paths)
 gif_mat = cat(images..., dims=3)
 save("./output.gif", gif_mat)
 ```
-<br>
-<p align="center">
-<img src="../../assets/tutorialposts/2021-10-8-dcgan-mnist/output.gif" align="middle" width="200">
-</p>
+~~~
+<div style="text-align:center">
+<img src="../../assets/tutorialposts/2021-10-08-dcgan-mnist/output.gif" width="200">
+</div>
+~~~
 
 ## Resources & References
 - [The DCGAN implementation in Model Zoo.](https://github.com/FluxML/model-zoo/blob/master/vision/dcgan_mnist/dcgan_mnist.jl)
