---
title: "Introduction to generative adversarial networks"
output:
  html_notebook:
    theme: cerulean
    highlight: textmate
---
```{r setup, include=FALSE}
knitr::opts_chunk$set(warning = FALSE, message = FALSE)
```

***

This notebook contains the code samples found in Chapter 8, Section 5 of [Deep Learning with R](https://www.manning.com/books/deep-learning-with-r). Note that the original text features far more content, in particular further explanations and figures: in this notebook, you will only find source code and related comments.

***

## A schematic GAN implementation

In this section, we'll explain how to implement a GAN in Keras in its barest form: because GANs are advanced, diving deeply into the technical details would be out of scope for this book. The specific implementation is a _deep convolutional GAN_ (DCGAN): a GAN where the generator and discriminator are deep convnets. In particular, it uses `layer_conv_2d_transpose()` for image upsampling in the generator.

We will train our GAN on images from CIFAR10, a dataset of 50,000 32x32 RGB training images belonging to 10 classes (5,000 images per class). To make things even easier, we will only use images belonging to the class "frog".
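
The training code later selects frogs with the label `6`. This value comes from CIFAR10's zero-based label order. The base-R sketch below shows that mapping. It assumes the standard CIFAR10 class order. The `cifar10_classes` vector is defined here only for illustration and is not provided by Keras.

```r
# Standard CIFAR10 class names, listed in label order (labels run from 0 to 9)
cifar10_classes <- c("airplane", "automobile", "bird", "cat", "deer",
                     "dog", "frog", "horse", "ship", "truck")

# R vectors are 1-based, so subtract 1 to get the zero-based CIFAR10 label
frog_label <- which(cifar10_classes == "frog") - 1
print(frog_label)
```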

Schematically, our GAN looks like this:

* A `generator` network maps vectors of shape `(latent_dim)` to images of shape `(32, 32, 3)`.
* A `discriminator` network maps images of shape `(32, 32, 3)` to a single score between 0 and 1 that classifies the image as real or generated (in the training loop below, values close to 1 mean "generated" and values close to 0 mean "real").
* A `gan` network chains the generator and the discriminator together: `gan(x) <- discriminator(generator(x))`. Thus the `gan` network maps latent-space vectors to the discriminator's assessment of the realism of these vectors as decoded by the generator.
* We train the discriminator using examples of real and fake images along with "real"/"fake" labels, as we would train any regular image classification model.
* To train the generator, we use the gradients of the `gan` model's loss with respect to the generator's weights. This means that, at every step, we move the weights of the generator in a direction that makes the discriminator more likely to classify the images decoded by the generator as "real". In other words, we train the generator to fool the discriminator.
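
Both models are compiled with binary cross-entropy. The base-R sketch below makes the two opposing objectives concrete. It uses made-up discriminator outputs for four generated and four real images. It follows the label convention of this notebook's training loop (1 = generated, 0 = real). The `bce()` helper is hypothetical. It implements the standard formula and does not call Keras.

```r
# Binary cross-entropy averaged over a batch; predictions are clipped away
# from 0 and 1 to avoid log(0), similar in spirit to Keras's epsilon clipping
bce <- function(y_true, y_pred, eps = 1e-7) {
  y_pred <- pmin(pmax(y_pred, eps), 1 - eps)
  -mean(y_true * log(y_pred) + (1 - y_true) * log(1 - y_pred))
}

# Hypothetical discriminator outputs (close to 1 = "generated", close to 0 = "real")
d_on_fake <- c(0.90, 0.80, 0.70, 0.95)
d_on_real <- c(0.10, 0.20, 0.05, 0.30)

# Discriminator objective: generated images labeled 1, real images labeled 0
d_loss <- bce(c(rep(1, 4), rep(0, 4)), c(d_on_fake, d_on_real))

# Generator objective (through `gan`): the same generated images,
# but with "misleading" targets of 0, i.e. "real"
g_loss <- bce(rep(0, 4), d_on_fake)

cat("discriminator loss:", round(d_loss, 3), "\n")
cat("generator loss:", round(g_loss, 3), "\n")
```

In this made-up situation the discriminator is winning. Its loss is low (about 0.184), while the generator's loss is high (about 2.028). Minimizing the generator's loss therefore pushes `d_on_fake` toward 0, which means "real".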

## A bag of tricks

Training GANs and tuning GAN implementations is notoriously difficult. There are a number of known "tricks" that one should keep in mind. Like most things in deep learning, it is more alchemy than science: these tricks are heuristics, not theory-backed guidelines. They are supported by some level of intuitive understanding of the phenomenon at hand, and they are known to work well empirically, albeit not necessarily in every context.

Here are a few of the tricks that we leverage in our own implementation of a GAN generator and discriminator below. It is not an exhaustive list of GAN-related tricks; you will find many more across the GAN literature.

* We use `tanh` as the last activation in the generator, instead of `sigmoid`, which is more commonly found in other types of models. The generator's outputs therefore lie in [-1, 1].
* We sample points from the latent space using a _normal distribution_ (Gaussian distribution), not a uniform distribution.
* Stochasticity is good to induce robustness. Because GAN training results in a dynamic equilibrium, GANs are likely to get stuck in all sorts of ways. Introducing randomness during training helps prevent this. We introduce randomness in two ways: by using dropout in the discriminator and by adding random noise to the labels for the discriminator.
* Sparse gradients can hinder GAN training. In deep learning, sparsity is often a desirable property, but not in GANs. Two things can induce gradient sparsity: max pooling operations and ReLU activations. Instead of max pooling, we recommend using strided convolutions for downsampling, and we recommend using `layer_activation_leaky_relu()` instead of a ReLU activation. It's similar to ReLU, but it relaxes sparsity constraints by allowing small negative activation values.
* In generated images, it's common to see checkerboard artifacts caused by unequal coverage of the pixel space in the generator (see figure 8.17). To fix this, we use a kernel size that is divisible by the stride size whenever we use a strided `layer_conv_2d_transpose()` or `layer_conv_2d()` in both the generator and the discriminator.
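
Two of these tricks can be checked numerically without Keras. The base-R sketch below does two things.

* It counts how many kernel applications overlap each output position of a strided 1D transposed convolution. This is a simplification of the 2D case that ignores border cropping.
* It compares the gradients of ReLU and leaky ReLU.

The helper names are hypothetical. The slope `alpha = 0.3` is an assumption meant to mirror the default of `layer_activation_leaky_relu()`.

```r
# Count how many kernel applications overlap each output position of a
# 1D transposed convolution with `n_in` inputs (no padding or cropping)
coverage <- function(n_in, kernel_size, stride) {
  n_out <- (n_in - 1) * stride + kernel_size
  counts <- integer(n_out)
  for (i in seq_len(n_in)) {
    first <- (i - 1) * stride + 1
    idx <- first:(first + kernel_size - 1)
    counts[idx] <- counts[idx] + 1L
  }
  counts
}

# Keep only interior positions (the borders always get less coverage)
interior <- function(counts, kernel_size) {
  counts[kernel_size:(length(counts) - kernel_size + 1)]
}

# Kernel 3, stride 2: coverage alternates between 2 and 1 (checkerboard-prone)
uneven <- interior(coverage(8, kernel_size = 3, stride = 2), 3)
# Kernel 4, stride 2: every interior position is covered exactly twice
even <- interior(coverage(8, kernel_size = 4, stride = 2), 4)
print(uneven)
print(even)

# ReLU vs. leaky ReLU: derivative with respect to the pre-activation
# (alpha = 0.3 is assumed to match layer_activation_leaky_relu()'s default)
relu_grad <- function(x) ifelse(x > 0, 1, 0)
leaky_relu_grad <- function(x, alpha = 0.3) ifelse(x > 0, 1, alpha)

set.seed(42)
pre_activations <- rnorm(1000)
# Share of units that pass no gradient back to earlier layers
dead_relu <- mean(relu_grad(pre_activations) == 0)         # roughly half
dead_leaky <- mean(leaky_relu_grad(pre_activations) == 0)  # exactly 0
```

With a kernel of 4 and a stride of 2, which is what the generator and discriminator below use, every interior output position receives the same number of contributions. With a kernel of 3, neighboring positions alternate between one and two.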

## The generator

First, we develop a `generator` model, which turns a vector from the latent space (during training, it will be sampled at random) into a candidate image. One of the many issues that commonly arise with GANs is that the generator gets stuck producing images that look like noise. A possible solution is to use dropout on both the discriminator and the generator.

```{r}
library(keras)

latent_dim <- 32
height <- 32
width <- 32
channels <- 3

generator_input <- layer_input(shape = c(latent_dim))

generator_output <- generator_input %>%
  # First, transform the input into a 16x16 128-channel feature map
  layer_dense(units = 128 * 16 * 16) %>%
  layer_activation_leaky_relu() %>%
  layer_reshape(target_shape = c(16, 16, 128)) %>%
  # Then, add a convolution layer
  layer_conv_2d(filters = 256, kernel_size = 5,
                padding = "same") %>%
  layer_activation_leaky_relu() %>%
  # Upsample to 32x32 (kernel size 4 is divisible by stride 2)
  layer_conv_2d_transpose(filters = 256, kernel_size = 4,
                          strides = 2, padding = "same") %>%
  layer_activation_leaky_relu() %>%
  # A few more conv layers
  layer_conv_2d(filters = 256, kernel_size = 5,
                padding = "same") %>%
  layer_activation_leaky_relu() %>%
  layer_conv_2d(filters = 256, kernel_size = 5,
                padding = "same") %>%
  layer_activation_leaky_relu() %>%
  # Produce a 32x32 3-channel (RGB) image with values in [-1, 1]
  layer_conv_2d(filters = channels, kernel_size = 7,
                activation = "tanh", padding = "same")

generator <- keras_model(generator_input, generator_output)
summary(generator)
```
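
We can trace by hand how the generator turns a `latent_dim`-dimensional vector into a 32x32 image. The base-R sketch below follows the spatial size through each layer. It assumes the usual Keras output-size rules for `padding = "same"`:

* A stride-`s` convolution produces `ceiling(n / s)`.
* A stride-`s` transposed convolution produces `n * s`.

The helper names are hypothetical. The computed parameter count should agree with the first layer's row in `summary(generator)`.

```r
# Output side length for padding = "same" (assumed Keras convention)
conv_same_out <- function(n, stride = 1) ceiling(n / stride)
conv_transpose_same_out <- function(n, stride = 1) n * stride

latent_dim <- 32
dense_units <- 128 * 16 * 16   # units of the first layer_dense()

side <- 16                                         # after layer_reshape()
side <- conv_same_out(side)                        # conv, kernel 5: stays 16
side <- conv_transpose_same_out(side, stride = 2)  # transposed conv: 32
for (i in 1:3) side <- conv_same_out(side)         # three more convs: stay 32

# Weights plus biases of the first dense layer
dense_params <- latent_dim * dense_units + dense_units

print(c(side = side, dense_units = dense_units, dense_params = dense_params))
```

The single strided transposed convolution does all the upsampling. Every other layer keeps the 32x32 resolution and only changes the number of channels.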

## The discriminator

Next, we develop a `discriminator` model, which takes as input a candidate image (real or synthetic) and classifies it into one of two classes: "generated image" or "real image that comes from the training set".

```{r}
discriminator_input <- layer_input(shape = c(height, width, channels))

discriminator_output <- discriminator_input %>%
  layer_conv_2d(filters = 128, kernel_size = 3) %>%
  layer_activation_leaky_relu() %>%
  layer_conv_2d(filters = 128, kernel_size = 4, strides = 2) %>%
  layer_activation_leaky_relu() %>%
  layer_conv_2d(filters = 128, kernel_size = 4, strides = 2) %>%
  layer_activation_leaky_relu() %>%
  layer_conv_2d(filters = 128, kernel_size = 4, strides = 2) %>%
  layer_activation_leaky_relu() %>%
  layer_flatten() %>%
  # One dropout layer - important trick!
  layer_dropout(rate = 0.4) %>%
  # Classification layer
  layer_dense(units = 1, activation = "sigmoid")

discriminator <- keras_model(discriminator_input, discriminator_output)
summary(discriminator)

# To stabilize training, we use learning rate decay
# and gradient clipping (by value) in the optimizer.
discriminator_optimizer <- optimizer_rmsprop(
  lr = 0.0008,
  clipvalue = 1.0,
  decay = 1e-8
)

discriminator %>% compile(
  optimizer = discriminator_optimizer,
  loss = "binary_crossentropy"
)
```
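
As a sanity check on the discriminator's architecture, we can compute the spatial size after each `layer_conv_2d()` by hand. The base-R sketch below assumes the standard output-size formula for the default `padding = "valid"`: `floor((n - kernel_size) / stride) + 1`. The helper `conv_valid_out()` is hypothetical. The resulting sizes should agree with `summary(discriminator)`.

```r
# Output side length for padding = "valid" (the layer_conv_2d() default)
conv_valid_out <- function(n, kernel_size, stride = 1) {
  floor((n - kernel_size) / stride) + 1
}

# (kernel_size, stride) of the four layer_conv_2d() calls above
conv_layers <- list(c(3, 1), c(4, 2), c(4, 2), c(4, 2))

side <- 32
sizes <- side
for (layer in conv_layers) {
  side <- conv_valid_out(side, layer[1], layer[2])
  sizes <- c(sizes, side)
}
print(sizes)

# Number of values produced by layer_flatten(): 2 x 2 x 128 filters
flattened_units <- side * side * 128
print(flattened_units)
```

The strided convolutions shrink the image from 32x32 to a 2x2x128 feature map. As a result, `layer_flatten()` hands 512 values to the dropout and classification layers.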

## The adversarial network

Finally, we set up the GAN, which chains the generator and the discriminator. This is the model that, when trained, will move the generator in a direction that improves its ability to fool the discriminator. This model turns latent-space points into a classification decision, "fake" or "real", and it is meant to be trained with labels that always say "these are real images". So training `gan` will update the weights of `generator` in a way that makes `discriminator` more likely to predict "real" when looking at fake images. Very importantly, we set the discriminator to be frozen (non-trainable) during this training: its weights will not be updated when training `gan`. If the discriminator weights could be updated during this process, we would be training the discriminator to always predict "real", which is not what we want!

```{r}
# Set discriminator weights to non-trainable
# (will only apply to the `gan` model)
freeze_weights(discriminator)

gan_input <- layer_input(shape = c(latent_dim))
gan_output <- discriminator(generator(gan_input))
gan <- keras_model(gan_input, gan_output)

gan_optimizer <- optimizer_rmsprop(
  lr = 0.0004,
  clipvalue = 1.0,
  decay = 1e-8
)

gan %>% compile(
  optimizer = gan_optimizer,
  loss = "binary_crossentropy"
)
```
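
Both optimizers above combine `clipvalue = 1.0` with a tiny `decay = 1e-8`. The base-R sketch below illustrates what these two settings do. It assumes the time-based decay formula of the legacy Keras 2 optimizers, `lr / (1 + decay * iterations)`, where `iterations` counts optimizer updates. The helper functions are hypothetical and are not part of the keras package.

```r
# Time-based learning-rate decay (assumed legacy Keras 2 formula)
decayed_lr <- function(lr, decay, iterations) lr / (1 + decay * iterations)

# Gradient clipping by value: every gradient component is clamped
# to [-clipvalue, clipvalue]
clip_by_value <- function(grads, clipvalue) {
  pmin(pmax(grads, -clipvalue), clipvalue)
}

# Discriminator learning rate after 10,000 updates (the iteration count
# used in the training loop below): it has barely moved
lr_end <- decayed_lr(0.0008, decay = 1e-8, iterations = 10000)
print(lr_end / 0.0008)

# Large gradient components are capped; small ones pass through unchanged
clipped <- clip_by_value(c(-3.2, -0.5, 0.1, 2.7), clipvalue = 1.0)
print(clipped)
```

At this scale the decay has almost no effect: the learning rate only drops by about 0.01% over the whole run. Clipping, by contrast, bounds the size of any single update when the adversarial losses spike.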

## How to train your DCGAN

Now we can begin training. To recapitulate, this is what the training loop looks like schematically. At each step of the loop, we do the following:

* Draw random points in the latent space (random noise).
* Generate images with `generator` using this random noise.
* Mix the generated images with real ones.
* Train `discriminator` using these mixed images, with corresponding targets: either "real" (for the real images) or "fake" (for the generated images).
* Draw new random points in the latent space.
* Train `gan` using these random vectors, with targets that all say "these are real images". This updates the weights of the generator only (because the discriminator is frozen inside `gan`), moving them toward getting the discriminator to predict "these are real images" for generated images: that is, this trains the generator to fool the discriminator.

Let's implement it.

```{r, echo=TRUE, results='hide'}
# Loads CIFAR10 data
cifar10 <- dataset_cifar10()
c(c(x_train, y_train), c(x_test, y_test)) %<-% cifar10

# Selects frog images (class 6)
x_train <- x_train[as.integer(y_train) == 6,,,]
# Normalizes data to [-1, 1], the output range of the generator's tanh
x_train <- x_train / 127.5 - 1

iterations <- 10000
batch_size <- 20
save_dir <- "gan_images"
dir.create(save_dir, showWarnings = FALSE)

# Start the training loop
start <- 1
for (step in 1:iterations) {

  # Samples random points in the latent space
  random_latent_vectors <- matrix(rnorm(batch_size * latent_dim),
                                  nrow = batch_size, ncol = latent_dim)

  # Decodes them to fake images
  generated_images <- generator %>% predict(random_latent_vectors)

  # Combines them with real images
  stop <- start + batch_size - 1
  real_images <- x_train[start:stop,,,]
  rows <- nrow(real_images)
  combined_images <- array(0, dim = c(rows * 2, dim(real_images)[-1]))
  combined_images[1:rows,,,] <- generated_images
  combined_images[(rows+1):(rows*2),,,] <- real_images

  # Assembles labels discriminating real from fake images
  # (1 = generated, 0 = real)
  labels <- rbind(matrix(1, nrow = batch_size, ncol = 1),
                  matrix(0, nrow = batch_size, ncol = 1))

  # Adds random noise to the labels -- an important trick!
  labels <- labels + (0.5 * array(runif(prod(dim(labels))),
                                  dim = dim(labels)))

  # Trains the discriminator
  d_loss <- discriminator %>% train_on_batch(combined_images, labels)

  # Samples random points in the latent space
  random_latent_vectors <- matrix(rnorm(batch_size * latent_dim),
                                  nrow = batch_size, ncol = latent_dim)

  # Assembles labels that say "all real images"
  misleading_targets <- array(0, dim = c(batch_size, 1))

  # Trains the generator (via the gan model, where the
  # discriminator weights are frozen)
  a_loss <- gan %>% train_on_batch(
    random_latent_vectors,
    misleading_targets
  )

  # Moves to the next batch of real images, wrapping around at the end
  start <- start + batch_size
  if (start > (nrow(x_train) - batch_size))
    start <- 1

  # Occasionally saves images
  if (step %% 100 == 0) {

    # Saves model weights
    save_model_weights_hdf5(gan, "gan.h5")

    # Prints metrics
    cat("discriminator loss:", d_loss, "\n")
    cat("adversarial loss:", a_loss, "\n")

    # Saves one generated image, rescaled from [-1, 1] to [0, 255]
    image_array_save(
      (generated_images[1,,,] + 1) * 127.5,
      path = file.path(save_dir, paste0("generated_frog", step, ".png"))
    )

    # Saves one real image for comparison, rescaled the same way
    image_array_save(
      (real_images[1,,,] + 1) * 127.5,
      path = file.path(save_dir, paste0("real_frog", step, ".png"))
    )
  }
}
```
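
The label bookkeeping in this loop is easy to get wrong. The base-R sketch below reproduces it on its own for a batch of 20. It uses the same convention as the loop (1 = generated, 0 = real) and the same additive uniform noise. Keras is not needed. The variable names mirror the loop, but the sketch is illustrative only.

```r
batch_size <- 20
set.seed(123)

# Discriminator targets: generated images first (1), then real images (0)
labels <- rbind(matrix(1, nrow = batch_size, ncol = 1),
                matrix(0, nrow = batch_size, ncol = 1))

# Additive uniform noise in [0, 0.5) turns them into "soft" labels
labels <- labels + 0.5 * array(runif(prod(dim(labels))), dim = dim(labels))

fake_targets <- labels[1:batch_size, 1]
real_targets <- labels[(batch_size + 1):(2 * batch_size), 1]
print(range(fake_targets))  # within [1, 1.5)
print(range(real_targets))  # within [0, 0.5)

# Generator targets: every generated image labeled "real" (0)
misleading_targets <- array(0, dim = c(batch_size, 1))
```

Because the noise is only ever added, the two groups of targets never overlap. The discriminator still sees a clear separation between "generated" and "real", just without perfectly hard 0/1 targets.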