I decided to start my blog by looking at the neoclassics of AI – the paper that introduced Generative Adversarial Networks [1].
In this post I want to illustrate the difficulties that GANs encounter during training. I will try to focus on the specifics, subtleties and intuitions behind the GAN model. I think that this post will be most useful for people with some background in machine learning and some familiarity with GANs (at the level of an introductory post), but I hope that the illustrations will be fun to look at in any case.
GAN refresher
The main idea of GANs is intuitive and fun: we have a generator (G) and a discriminator (D) that compete against each other. A discriminator is trained to distinguish between the genuine samples (samples that come from the true data distribution) and the fake samples (samples produced by some other process). A generator is trained to fool the discriminator by producing samples that it would mistakenly classify as genuine. If it succeeds, we would be able to get samples from our target distribution (if it does not sound exciting – be sure to take a look at some applications).
Since a traditional neural network is deterministic, we need to give the generator some source of randomness (otherwise it would always output the same thing, which would make the discriminator’s job too easy and the model useless). The cool thing is that we could start with a fairly arbitrary distribution that we can easily sample from and let the model transform it into what we want it to be.
Let’s illustrate this transformation idea with an example from a tutorial on Variational Autoencoders [2]: suppose that we have a bivariate normal distribution sampler, but we want to sample from a circle-shaped 2D distribution. We could apply a transformation to get what we want:
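Here is a minimal sketch of one such transformation (my own illustration, not necessarily the exact transformation used in [2]): we sample from a bivariate standard normal and simply push every point onto the unit circle.

```python
import numpy as np

rng = np.random.default_rng(0)

# Sample from a bivariate standard normal distribution.
z = rng.standard_normal(size=(1000, 2))

# Hand-crafted transformation: project every sample onto the unit circle.
radius = np.linalg.norm(z, axis=1, keepdims=True)
x = z / radius

# x now follows a circle-shaped 2D distribution,
# even though we only had a Gaussian sampler to begin with.
```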
In the case of GANs, we let the network infer the transformation for us, but the idea is similar.
The question of what is a suitable initial distribution is not entirely trivial, and I may cover it in more detail in a separate post. For now let’s assume that it would work as long as we provide “enough” randomness (i.e. our latent variable is continuous and has enough dimensions).
The widespread usage of relatively low-dimensional latent distributions therefore hinges on the manifold hypothesis. This hypothesis, ubiquitous in ML, states that the data we observe actually lies on some manifold: a lower-dimensional surface embedded in a high-dimensional space. For example, the circle distribution above is essentially one-dimensional (one line), even though it “lives” in a 2D space.
These intuitions are summarized in the following training objective (also called a game value function, since the discriminator (D) and generator (G) play an adversarial game against each other):

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$
High negative values of $V(D, G)$ correspond to the generator “winning”, while values close to zero (the maximum possible) are a success for the discriminator.
The rightmost expectation reflects the “success” of the generator. This expectation is taken with respect to a random variable $z$* that we sample from some distribution $p_z(z)$. It just says that we sample this variable, feed it to the generator, and then feed the resulting fake datapoint to the discriminator. If the discriminator confidently thinks that it is a real sample, it would output something close to 1, and the term under the expectation, $\log(1 - D(G(z)))$, will be a large negative number. It would essentially mean that the generator is winning the game.
*Note that $z$ could be a multivariate variable. A common choice is a multivariate normal with unit variance.
The leftmost expectation is only relevant to the discriminator. It just says that when we sample a datapoint from the true data distribution, the discriminator should aim to give it a high confidence score. The poor discriminator needs to balance giving high confidence to the real data while not giving high confidence to the fake examples made by the generator. I would say that the game is not really fair, and the paper presents a formal proof that the game equilibrium is reached when the generator learns the true data distribution and the discriminator gives up and admits its incompetence, always predicting 0.5.
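To make the objective concrete, here is a minimal sketch of one alternating training update written in PyTorch (my own illustration, not the exact code from the linked repository; it assumes `G` and `D` are small networks, `D` ends with a sigmoid so it outputs a probability, and `opt_G`, `opt_D` are their optimizers):

```python
import torch

# One alternating update of the vanilla GAN objective
# V(D, G) = E_x[log D(x)] + E_z[log(1 - D(G(z)))].
def gan_step(G, D, opt_G, opt_D, real_batch, latent_dim, eps=1e-8):
    batch_size = real_batch.shape[0]

    # Discriminator step: maximize log D(x) + log(1 - D(G(z)))
    # (implemented as minimizing the negative of it).
    z = torch.randn(batch_size, latent_dim)
    fake_batch = G(z).detach()  # do not backprop into the generator here
    d_loss = -(torch.log(D(real_batch) + eps).mean()
               + torch.log(1 - D(fake_batch) + eps).mean())
    opt_D.zero_grad()
    d_loss.backward()
    opt_D.step()

    # Generator step: minimize log(1 - D(G(z))).
    # (In practice the non-saturating -log D(G(z)) loss is often used instead.)
    z = torch.randn(batch_size, latent_dim)
    g_loss = torch.log(1 - D(G(z)) + eps).mean()
    opt_G.zero_grad()
    g_loss.backward()
    opt_G.step()

    return d_loss.item(), g_loss.item()
```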
Why do we need to learn this distribution?
The beauty of the idea is that the target distribution could be almost anything. It could be as simple as a one-dimensional uniform, or as complicated as a distribution over all possible images in the world*. In this post, I aim for understanding and truth, as opposed to Beauty, so I resort to experiments on GANs learning 1- and 2-dimensional distributions.
*In the traditional framework, the data should be continuous, but this restriction could be relaxed.
GAN learning issues visualization
(Non) Convergence
Unfortunately, the adversarial nature of the training procedure leads to GANs often being stubbornly non-converging. The original paper, however, presents a nice illustration where a GAN is learning a one-dimensional distribution. I decided to see how this picture would look in real life.
What I saw was an eternal struggle between a red histogram (generator) that tries to sneak under the green line (discriminator) and the green line that tries to cover the blue region (observed data), while not covering the pesky red histogram.
In theory, the green line should eventually give up and settle to a compromise solution, while the red histogram should settle to the true data distribution. But this theory assumes infinite data and infinite model capacity.
As you can see, this is not quite what happens in practice, although the red histogram does learn something resembling the blue distribution. It is interesting that a model that fails to converge even in this toy example is able to produce state-of-the-art results in a range of domains.
What I also find amusing is that the model manages to entangle itself and produce a non-monotonic mapping even in the simple 1D case. It makes me wonder: would the entanglement become more intense in 2/3/n dimensions?
Mode collapse
Another well-known issue with GANs is mode collapse. Essentially, the generator might decide to focus on one of the modes of the target distribution instead of modeling all of it. It might also jump from mode to mode in order to avoid the discriminator, which keeps pursuing it.
The gif below depicts a typical scenario: some mode or modes are ignored and the model is unstable. From time to time, utter chaos ensues: the model moves all of its predictions to a new mode, forgets the previously learned ones, predictions spread all over the place, etc. Eventually, however, everything usually goes back to normal.
In practice, we don’t have the luxury of seeing the predicted distributions so clearly, and so stopping training at an unfortunate time may present a serious issue.
Catastrophic forgetting
It is interesting to note that the discriminator confidence does not go all the way to zero between the clusters. It is not intuitive: there is no data there, so there is no motivation for the discriminator to keep its confidence above zero in these regions (why give free points to the generator?). My first assumption was that the network was not powerful enough. However, that was not the case: increasing the capacity (number of neurons) of the model does not change this behavior.
I think that a likely explanation is that this is an example of catastrophic forgetting, a problem common to most neural network architectures. Roughly speaking, networks forget what they’ve learned when we teach them something new. In our case, since there are no (or very few) generated samples in the “between-cluster” space, the discriminator is not motivated to keep its confidence low in those spots, and it starts to drift back. An interesting question is whether GAN learning could be improved by using methods for preventing catastrophic forgetting ([3], for example). *After some googling*: apparently this connection was recently explored in a fresh arXiv pre-print [4], but it seems that it might still be an interesting research direction.
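For a rough, hypothetical sketch of what such a method could look like (my own illustration, not the approach of [3] or [4]): one popular family of anti-forgetting techniques adds a quadratic penalty that anchors the weights to a previous snapshot, weighted by an importance estimate such as a diagonal Fisher approximation. Applied to the discriminator, it might look like this:

```python
import torch

def anti_forgetting_penalty(model, old_params, importance, strength=1.0):
    """EWC-style quadratic penalty: discourage parameters from drifting
    away from a previous snapshot, weighted by their estimated importance."""
    penalty = 0.0
    for name, param in model.named_parameters():
        penalty = penalty + (importance[name] * (param - old_params[name]) ** 2).sum()
    return 0.5 * strength * penalty

# Hypothetical usage inside the discriminator update:
# d_loss = d_loss + anti_forgetting_penalty(D, old_params, importance, strength=10.0)
```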
Failure cases
Another noticeable problem is the “swaying” motion of the mapping at the beginning of learning. It can happen when the decision surface initially has such a shape that it is locally beneficial for the generator to move away from the data region. Sometimes this results in utter learning failures: the generator moves too far away from the data, after which the discriminator collapses its confidence outside the data region. As a result, the generator gets stuck and never finds its way back:
Moving to 2D
The one-dimensional case is very unnatural and limiting, so I decided to move on to learning a 2D distribution. Our data consists of five normally distributed clusters:
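Generating such a dataset is straightforward; here is a sketch with illustrative cluster centers and spreads (not necessarily the exact values behind the figures in this post):

```python
import numpy as np

rng = np.random.default_rng(42)

# Five cluster centers laid out in 2D (illustrative values only).
centers = np.array([[0.0, 2.0], [2.0, 0.0], [0.0, -2.0], [-2.0, 0.0], [0.0, 0.0]])
std = 0.3
samples_per_cluster = 200

# Each cluster is an isotropic Gaussian around its center.
data = np.concatenate([
    center + std * rng.standard_normal(size=(samples_per_cluster, 2))
    for center in centers
])

rng.shuffle(data)  # mix the clusters so training batches are not sorted by cluster
```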
I trained a vanilla GAN for 2000 weight updates. Below are three different representations of the learning process.
In the first video, I visualize the data distribution (static purple clusters) and the mapping from the latent space to the observed space (arrows). This mapping changes as the model is trained, producing an extremely amusing (to my taste) animation.
Note that the latent space is standard normal, so most (~95%) of the latent samples lie in the (-2, 2) interval. Since examples outside of this region are rare, their mapping is not very meaningful and makes the picture cluttered.
The mapping is highly entangled, and oscillates for a long time, illustrating the lack of convergence.
The second video demonstrates the interplay between the generator and the discriminator. Generated samples are visualized as red patches (the deeper the color, the more samples fall into that region), while the discriminator’s decisions are visualized as a purple-white contour plot (the discriminator thinks that the white regions are “fake data”, while the purple regions are classified as “true data”).
Here we can see how the red patches try to stay on top of the purple regions, while the purple regions try to cover the data and avoid the red patches. This leads to a mild case of mode collapse: we see the red patches shift their concentration between different clusters.
The third video is similar to the second one, but visualized in 3D. The red patches are the regions where the generator places its samples, while the mountains depict the discriminator’s decision surface (normalized for convenience).
This video nicely illustrates the mode collapse issue: we see that even during the late stages of learning, the peaks often disappear or grow very small, which means that the discriminator “forgets” about certain clusters (probably because there were too many fake examples there). The red patches then also leave those abandoned clusters, chasing the discriminator’s remaining peaks. After that, the peaks grow back and the cycle repeats.
Overall, I hope that these examples helped to demonstrate why GANs have convergence issues. In future posts, I may look at GAN modifications that aim to eliminate or mitigate these and other issues.
My (unpolished) code with a vanilla GAN implementation and an IPython notebook for generating the visualizations in this post are available at https://github.com/R-seny/time-to-reproduce
If you have any comments, or suggestions for further posts – let me know!