tutorialposts/2021-10-08-dcgan-mnist.md
This is a beginner-level tutorial for generating images of handwritten digits using a DC-GAN.
A GAN is composed of two sub-models, the **generator** and the **discriminator**, acting against one another. The generator can be thought of as an artist who draws (generates) new images that look real, whereas the discriminator is a critic who learns to tell real images apart from fakes.
The GAN starts with a generator and a discriminator that have very little or no idea about the underlying data. During training, the generator progressively becomes better at creating images that look real, while the discriminator becomes better at telling them apart. The process reaches equilibrium when the discriminator can no longer distinguish real images from fakes.
This tutorial demonstrates the process of training a DC-GAN on the MNIST dataset of handwritten digits.
```julia
function Generator(latent_dim)
    # ... the generator's Chain of layers is defined here (omitted in this
    # excerpt; only the closing lines of the function are shown) ...
    )   # closes the Chain(...) opened above
end
```
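
Since the excerpt above only shows the closing lines of the generator, here is a rough sketch of what a full DC-GAN generator of this shape could look like: a `Dense` projection, a reshape to a 7×7×256 feature map, and a stack of `ConvTranspose` layers that upsample to a 28×28×1 image. The layer sizes are illustrative assumptions and may differ from the ones used in the post.

```julia
using Flux

function Generator(latent_dim)
    Chain(
        Dense(latent_dim, 7 * 7 * 256),                  # project the latent vector
        BatchNorm(7 * 7 * 256, relu),
        x -> reshape(x, 7, 7, 256, :),                   # 7×7×256 feature map
        ConvTranspose((5, 5), 256 => 128; stride = 1, pad = 2),      # stays 7×7
        BatchNorm(128, relu),
        ConvTranspose((4, 4), 128 => 64; stride = 2, pad = 1),       # 7×7 -> 14×14
        BatchNorm(64, relu),
        ConvTranspose((4, 4), 64 => 1, tanh; stride = 2, pad = 1),   # 14×14 -> 28×28×1
    )
end
```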
Time for a small test!! We create a dummy generator and feed a random vector as a seed to the generator. If our generator is initialized correctly, it will return an array of size (28, 28, 1, `batch_size`). The `@assert` macro in Julia will raise an exception if the output size is wrong.
```julia
# `generator` and `noise` come from earlier lines not shown in this excerpt;
# the values below are assumptions (latent_dim = 100, a dummy batch of 3).
generator = Generator(100)
noise = randn(Float32, 100, 3)
gen_image = generator(noise)
@assert size(gen_image) == (28, 28, 1, 3)
```
Our generator model has not yet learned the correct weights, so it does not produce a recognizable image for now. To train our poor generator we need its equal rival, the *discriminator*.
Just like our dummy generator, the untrained discriminator has no idea what a real or fake image looks like. It needs to be trained alongside the generator to output positive values for real images and negative values for fake images.
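
The discriminator itself is defined in parts of the tutorial not shown in this excerpt. As a point of reference, a typical DC-GAN discriminator for 28×28×1 MNIST images might look like the sketch below; the kernel sizes, channel counts, and dropout rate are assumptions rather than the post's exact values. It downsamples the image with strided convolutions and ends in a single raw score (logit) per image.

```julia
using Flux

function Discriminator()
    Chain(
        Conv((4, 4), 1 => 64; stride = 2, pad = 1),      # 28×28 -> 14×14
        x -> leakyrelu.(x, 0.2f0),
        Dropout(0.3),
        Conv((4, 4), 64 => 128; stride = 2, pad = 1),     # 14×14 -> 7×7
        x -> leakyrelu.(x, 0.2f0),
        Dropout(0.3),
        x -> reshape(x, 7 * 7 * 128, :),                  # flatten
        Dense(7 * 7 * 128, 1),                            # one logit per image
    )
end
```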
## Loss functions for GAN
In a GAN problem, there are only two labels involved: fake and real, so binary cross-entropy is an easy choice for a preliminary loss function.
Although Flux's `binarycrossentropy` would do the job, for numerical stability it is preferable to compute the cross-entropy from logits. Flux provides [logitbinarycrossentropy](https://fluxml.ai/Flux.jl/stable/models/losses/#Flux.Losses.logitbinarycrossentropy) specifically for this purpose. Mathematically, it is equivalent to `binarycrossentropy(σ(ŷ), y, kwargs...)`.
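
As a quick sanity check of that equivalence, with a small vector of logits `ŷ` and binary labels `y` (both made up here for illustration):

```julia
using Flux
using Flux.Losses: binarycrossentropy, logitbinarycrossentropy

ŷ = randn(Float32, 8)          # raw discriminator outputs (logits)
y = Float32.(rand(Bool, 8))    # 1 = real, 0 = fake

# Same value, but the logit form never materializes σ(ŷ), which is where
# the numerical trouble with plain binarycrossentropy comes from.
@assert logitbinarycrossentropy(ŷ, y) ≈ binarycrossentropy(sigmoid.(ŷ), y)
```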
### Discriminator Loss
```julia
function discriminator_loss(real_output, fake_output)
    # Only the closing lines of this function appear in the excerpt; the two
    # loss terms below are assumed: real images are scored against the label 1
    # and fake images against the label 0.
    real_loss = logitbinarycrossentropy(real_output, 1)
    fake_loss = logitbinarycrossentropy(fake_output, 0)
    return real_loss + fake_loss
end
```
### Generator Loss
The generator's loss quantifies how well it was able to trick the discriminator. Intuitively, if the generator is performing well, the discriminator will classify the fake images as real (or 1).
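
Using the same logit-based loss as above, a minimal sketch of such a generator loss (assuming the discriminator's raw outputs on fake images are passed in) could be:

```julia
using Flux.Losses: logitbinarycrossentropy

# Hypothetical generator loss: every fake output is scored against the
# "real" label 1, so the loss is small only when the discriminator is fooled.
generator_loss(fake_output) = logitbinarycrossentropy(fake_output, 1)
```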
We also need optimizers for our networks. Why, you may ask? Read more [here](https://towardsdatascience.com/overview-of-various-optimizers-in-neural-networks-17c1be2df6d5). For both the generator and the discriminator, we will use the [ADAM optimizer](https://fluxml.ai/Flux.jl/stable/training/optimisers/#Flux.Optimise.ADAM).
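
Setting the two optimizers up might look like the following sketch; the learning rate is an assumption, not necessarily the value used in the post.

```julia
using Flux.Optimise: ADAM

lr = 0.0002f0          # assumed learning rate
opt_dscr = ADAM(lr)    # optimizer state for the discriminator
opt_gen  = ADAM(lr)    # optimizer state for the generator
```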
Now that we have defined every function we need, we integrate everything into a single `train` function where we first set up all the models and optimizers and then train the GAN for a specified number of epochs.
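
A rough sketch of how these pieces could fit together is given below; the helpers `load_data`, `train_discriminator!`, and `train_generator!` are hypothetical stand-ins for functions the full tutorial defines, and only the overall control flow is shown.

```julia
function train(; latent_dim = 100, epochs = 20, lr = 0.0002f0)
    data = load_data()                 # batched MNIST images, 28×28×1×batch
    dscr = Discriminator()
    gen  = Generator(latent_dim)
    opt_dscr = ADAM(lr)
    opt_gen  = ADAM(lr)

    for epoch in 1:epochs
        for batch in data
            # Fresh noise for every batch; the batch size is the last dimension.
            noise = randn(Float32, latent_dim, size(batch)[end])
            train_discriminator!(gen, dscr, batch, noise, opt_dscr)
            train_generator!(gen, dscr, noise, opt_gen)
        end
        @info "Finished epoch $epoch"
    end
    return gen, dscr
end
```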