A Generative Adversarial Network (GAN) is a combination of two sub-networks that compete with each other during training in order to generate realistic data. The Generator Network generates genuine-looking artificial data, while the Discriminator Network identifies whether the data is artificial or real.

While GANs are powerful models, they can be rather difficult to train. We train the Generator and the Discriminator simultaneously, each at the expense of the other. This makes training a dynamic system: as soon as the parameters of one model are updated, the nature of the optimization problem changes, and because of this, reaching convergence can be difficult.

Training can also result in the **failure of GANs** to model the complete data distribution; this is called **Mode Collapse**.

In this article:

- we’ll see how to train a stable GAN model,
- and then we’ll play around with the training process to understand the possible reasons for failure modes.

I have been training GANs for the past few years, and in my experience the most common failure modes are **Mode Collapse** and **Convergence Failure**, which we’ll discuss in this article.

## Training a stable GAN network

To understand how failures in GAN training can occur, let’s first train a stable GAN network. We’ll use the MNIST dataset; our objective is to generate artificial handwritten digits from random noise using the generator network.

The generator takes random noise as input and outputs fake 28×28 handwritten digits. The discriminator takes 28×28 images from both the generator and the ground truth as input and tries to classify them correctly.

I used the Adam optimizer with a learning rate of 0.0002 and 0.5 as its beta1 (momentum) parameter.

Let’s have a look at the code of our stable GAN network. First, let’s make the required imports.

```
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.datasets as datasets
import numpy as np
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from tqdm import tqdm
import neptune
from neptune.types import File
```

Note that we’ll be using PyTorch to train our model and neptune.ai’s dashboard for experiment tracking. Here’s a link to all my experiments. I ran the scripts in Colab, and Neptune made it really easy to track all the experiments.

Proper experiment tracking is really important in this case, because loss graphs and intermediate images help a lot in identifying a failure mode. Alternatively, depending on your use case and preference, you can use Matplotlib, Sacred, TensorBoard, etc.

We first initialize a Neptune run. You can get the project path and API token once you create a project in the Neptune dashboard.

```
run = neptune.init_run(
    project="project name",
    api_token="Your API token",
)
```

We keep the batch size at 1024 and run for 100 epochs. The latent dimension is the size of the random noise vector fed to the generator. The sample size is the number of images (64) we generate at the end of each epoch, so we can visualize how image quality evolves. k is the number of discriminator steps we run for each generator step.

```
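batch_size = 1024
epochs = 100
sample_size = 64  # number of images generated after each epoch for visual inspection
latent_dim = 128  # size of the random noise vector fed to the generator
k = 1  # discriminator steps per generator step
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # scale pixels to [-1, 1], matching the generator's tanh output
])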
```

Now we download the MNIST data and create a DataLoader object.

```
train_data = datasets.MNIST(
    root='../input/data',
    train=True,
    download=True,
    transform=transform
)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
```

Finally, we define some training hyperparameters and log them to the Neptune dashboard using the run object.

```
params = {
    "learning_rate": 0.0002,
    "optimizer": "Adam",
    "optimizer_betas": (0.5, 0.999),
    "latent_dim": latent_dim,
}
run["parameters"] = params
```

This is where we define the generator and discriminator networks.

### Generator network

- The generator takes a latent vector, i.e. random noise, as input.
- The first (fully connected) layer maps the 128-dimensional latent vector to a feature map of 128 channels, each 7×7 in height and width.
- The next two transposed-convolution (deconvolution) layers each double the height and width, from 7×7 to 14×14 and then to 28×28.
- These are followed by a convolution layer with tanh activation that produces a single-channel 28×28 image.
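To double-check that these layers really end at 28×28, we can apply the standard output-size formulas for PyTorch's ConvTranspose2d and Conv2d (assuming dilation 1 and no output padding, as in the code below). The helper functions are illustrative only and are not part of the training script:

```python
def conv_transpose_out(size, kernel, stride, padding):
    # nn.ConvTranspose2d output size with dilation=1 and output_padding=0
    return (size - 1) * stride - 2 * padding + kernel


def conv_out(size, kernel, stride=1, padding=0):
    # nn.Conv2d output size with dilation=1
    return (size + 2 * padding - kernel) // stride + 1


size = 7  # spatial size after reshaping the linear layer's output to (128, 7, 7)
for _ in range(2):  # two transposed convolutions: kernel 4, stride 2, padding 1
    size = conv_transpose_out(size, kernel=4, stride=2, padding=1)
    print(size)  # 14, then 28
size = conv_out(size, kernel=3, stride=1, padding=1)  # final 3x3 conv keeps the size
print(size)  # 28
```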

```
class Generator(nn.Module):
    def __init__(self, latent_space):
        super(Generator, self).__init__()
        self.latent_space = latent_space
        self.fcn = nn.Sequential(
            nn.Linear(in_features=self.latent_space, out_features=128*7*7),
            nn.LeakyReLU(0.2),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128, out_channels=1, kernel_size=(3, 3), padding=(1, 1)),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.fcn(x)
        x = x.view(-1, 128, 7, 7)
        x = self.deconv(x)
        return x
```


### Discriminator network

- Our discriminator has two convolutional layers that extract features from the input image, whether it comes from the generator or from the real dataset.
- These are followed by a classifier layer (a linear layer with a sigmoid), which predicts whether the image is real or fake.
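The in_features=3136 of the classifier follows from the convolution arithmetic: each 4×4 convolution with stride 2 and padding 1 halves the spatial size, so 28×28 becomes 14×14 and then 7×7, and 64 channels × 7 × 7 = 3136. A quick check with the standard Conv2d output-size formula (assuming dilation 1); this helper is illustrative only:

```python
def conv_out(size, kernel, stride=1, padding=0):
    # nn.Conv2d output size with dilation=1
    return (size + 2 * padding - kernel) // stride + 1


size = 28  # MNIST images are 28x28
for _ in range(2):  # two convolutions: kernel 4, stride 2, padding 1
    size = conv_out(size, kernel=4, stride=2, padding=1)
channels = 64  # out_channels of the second convolution
flattened = channels * size * size  # length of x.view(x.size(0), -1) per image
print(size, flattened)  # 7 3136
```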

```
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.LeakyReLU(0.2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(in_features=3136, out_features=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
```

Now we initialize the generator and discriminator networks, the optimizers, and the loss function.

We also define some helper functions: label_real and label_fake create labels for real and fake images (where size is the batch size), and create_noise generates the random input for the generator.

```
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
optim_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optim_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
criterion = nn.BCELoss()


def label_real(size):
    labels = torch.ones(size, 1)
    return labels.to(device)


def label_fake(size):
    labels = torch.zeros(size, 1)
    return labels.to(device)


def create_noise(sample_size, latent_dim):
    return torch.randn(sample_size, latent_dim).to(device)
```

### Generator training function

Now we’ll train the generator:

- The generator takes in random noise and produces fake images.
- These fake images are passed to the discriminator, and we minimize the loss between the discriminator’s predictions on them and the "real" label, i.e. the generator is rewarded when the discriminator mistakes its fakes for real images.
- This function returns the generator loss, which we’ll be monitoring.

```
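def train_generator(optimizer, data_fake):
    b_size = data_fake.size(0)
    real_label = label_real(b_size)
    optimizer.zero_grad()
    output = discriminator(data_fake)
    # The generator wants the discriminator to label its fakes as real
    loss = criterion(output, real_label)
    loss.backward()
    optimizer.step()
    return loss.item()  # a Python float, so summing losses does not keep autograd graphs alive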
```

### Discriminator training function

We create a function train_discriminator:

- As we know, during training this network receives input both from the ground truth (i.e. real images) and from the generator network (i.e. fake images).
- We pass the real and the fake images one after the other, calculate the loss for each, and backpropagate. We’ll monitor two discriminator losses: the loss on real images (loss_real) and the loss on fake images (loss_fake).

```
def train_discriminator(optimizer, data_real, data_fake):
    b_size = data_real.size(0)
    real_label = label_real(b_size)
    fake_label = label_fake(b_size)
    optimizer.zero_grad()
    output_real = discriminator(data_real)
    loss_real = criterion(output_real, real_label)
    output_fake = discriminator(data_fake)
    loss_fake = criterion(output_fake, fake_label)
    # Gradients from both losses accumulate before the single optimizer step
    loss_real.backward()
    loss_fake.backward()
    optimizer.step()
    return loss_real.item(), loss_fake.item()
```

### GAN model training

Now that we have all the functions, let’s train our model and look at the observations to identify if the training is stable or not.

- The noise created in the first line is used to generate intermediate images after each epoch. We keep this noise fixed so we can compare images across epochs.
- For each batch, we train the discriminator k times (once here, since k=1) for every generator update.
- All the losses are recorded and sent to the Neptune dashboard for plotting. We don’t need to append them to a list: the Neptune dashboard plots the loss graphs on the fly. Neptune also records the loss at each step in a .csv file.
- I saved the generated images after each epoch to Neptune using the .upload() method.

```
noise = create_noise(sample_size, latent_dim)  # fixed noise, reused every epoch so samples are comparable
generator.train()
discriminator.train()
for epoch in range(epochs):
    loss_g = 0.0
    loss_d_real = 0.0
    loss_d_fake = 0.0
    # training
    for bi, data in tqdm(enumerate(train_loader), total=len(train_loader)):
        image, _ = data
        image = image.to(device)
        b_size = len(image)
        # k discriminator updates per generator update
        for step in range(k):
            data_fake = generator(create_noise(b_size, latent_dim)).detach()
            data_real = image
            step_loss_real, step_loss_fake = train_discriminator(optim_d, data_real, data_fake)
            loss_d_real += step_loss_real
            loss_d_fake += step_loss_fake
        data_fake = generator(create_noise(b_size, latent_dim))
        loss_g += train_generator(optim_g, data_fake)
    # inference and observations
    with torch.no_grad():
        generated_img = generator(noise).cpu()
    generated_img = (generated_img + 1) / 2  # map tanh output from [-1, 1] to [0, 1] for display
    generated_img = make_grid(generated_img)
    generated_img = np.moveaxis(generated_img.numpy(), 0, -1)
    run[f'generated_img/{epoch}'].upload(File.as_image(generated_img))
    # Average over all batches (bi is zero-based, so dividing by bi would be off by one)
    num_batches = len(train_loader)
    epoch_loss_g = loss_g / num_batches
    epoch_loss_d_real = loss_d_real / (num_batches * k)
    epoch_loss_d_fake = loss_d_fake / (num_batches * k)
    run["train/loss_generator"].log(epoch_loss_g)
    run["train/loss_discriminator_real"].log(epoch_loss_d_real)
    run["train/loss_discriminator_fake"].log(epoch_loss_d_fake)
    print(f"Epoch {epoch} of {epochs}")
    print(f"Generator loss: {epoch_loss_g:.8f}, Discriminator loss fake: {epoch_loss_d_fake:.8f}, Discriminator loss real: {epoch_loss_d_real:.8f}")
```

Let’s have a look at the intermediate images.

#### Epoch 10

These are 64 digits generated at epoch 10.

#### Epoch 100

These were generated from the same noise at epoch 100. They look much better than the images at epoch 10; here we can actually identify different digits. We can train for more epochs or tune the hyperparameters to get even better images.

### Loss graphs

In your Neptune dashboard, you can go to “Add New Dashboard” and merge different loss graphs into one chart.

In Fig. 3, you can see the losses stabilizing after epoch 40. The discriminator losses for real and fake images remain around 0.6, while the generator loss stays around 0.8. This is the expected graph for stable training. We can treat it as a baseline and experiment with changing k (the number of discriminator training steps), increasing the number of epochs, etc.
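A useful reference point for reading these numbers: if the discriminator cannot tell real from fake and outputs 0.5 for every image, the binary cross-entropy is −ln(0.5) = ln 2 ≈ 0.693 for the real term, the fake term, and the generator loss alike, and the plateaus in Fig. 3 are close to this value. A confident discriminator looks very different. A short plain-Python check, where the bce helper is a hypothetical single-sample version of what nn.BCELoss computes:

```python
import math


def bce(prediction, target):
    # Binary cross-entropy for a single prediction, as nn.BCELoss computes it
    return -(target * math.log(prediction) + (1 - target) * math.log(1 - prediction))


d_out = 0.5  # discriminator output when it is maximally unsure
loss_d_real = bce(d_out, 1.0)  # real image, target "real"
loss_d_fake = bce(d_out, 0.0)  # fake image, target "fake"
loss_g = bce(d_out, 1.0)       # the generator wants its fakes labelled "real"
print(round(loss_d_real, 3), round(loss_d_fake, 3), round(loss_g, 3))  # 0.693 0.693 0.693

d_confident = 0.01  # discriminator is almost sure a fake image is fake
print(round(bce(d_confident, 0.0), 3), round(bce(d_confident, 1.0), 3))  # 0.01 4.605
```

In the second case the discriminator's fake loss is close to zero while the generator loss is large, which is the kind of imbalance to watch for in the loss graphs.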

Now that we have built a stable GAN model, let’s look at the failure modes.


## GAN failure modes

In recent years, we have seen a rapid increase in GAN applications, whether to increase the resolution of images, for conditional generation, or to generate realistic synthetic data.

Training failure is a difficult problem for such applications.

**How to identify GAN failure modes? How do we know if there’s a failure mode:**

- The generator should ideally produce a variety of data. If it produces a single kind of output, or a very similar set of outputs, there’s **Mode Collapse**.
- When a visually poor set of data is generated, this might be a case of **Convergence Failure**.

**What causes mode collapse in GAN? Reasons for failure modes:**

- Inability to find convergence for networks.
- The generator can find a certain type of data that easily fools the discriminator. It will then keep generating the same data over and over, assuming the goal has been achieved. The entire system can over-optimize for that single type of output.

The problem with identifying mode collapse and other failure modes is that we cannot rely solely on qualitative analysis (such as manually inspecting the data). This approach can fail when there’s a huge amount of data or when the problem is very complex (we won’t always be generating digits).

## Evaluating failure modes

In this section, we’ll look at how to identify mode collapse or convergence failure using three evaluation methods, one of which we have already discussed in the previous section.

### Looking at the intermediate images

Let’s look at some examples where the intermediate images reveal mode collapse or convergence failure. In Fig. 4 we see really poor-quality images, and in Fig. 5 we see the same set of images generated again and again.

While Fig. 4 is an example of convergence failure, Fig. 5 shows mode collapse. You can get an idea of how your model is performing by looking at the images manually. But when the problem complexity is high or the training data is too big, you might not be able to identify mode collapse.

Let’s look at some better methods.

### By observing loss graphs

We can learn a lot about what’s happening by looking at the loss graphs. For example, in Fig. 3 you can see the losses saturating after a certain point, which is the expected behavior. Now look at the loss graph in Fig. 6, where I reduced the latent dimension: the behavior is erratic.

In Fig. 6, the generator loss oscillates between 1 and 1.2. The discriminator losses for fake and real images also hover around 0.6, but they are somewhat higher than in the stable version.

A graph with high variance is not necessarily a problem. You can increase the number of epochs, give the training more time to stabilize, and, most importantly, keep checking the intermediate images being generated.

If the loss drops to zero in the first few epochs for both the generator and the discriminator, that is a problem as well. It means the generator has found a set of fake images that are really easy for the discriminator to identify.

### Number of statistically-different bins (NDB Score)

Unlike the two qualitative methods above, the NDB score is a quantitative method. So instead of looking at images and loss graphs, and possibly missing something or misinterpreting it, we can use the NDB score to identify whether there’s mode collapse.

Let’s understand how NDB scoring works:

- We have two sets: a training set (on which the model is trained) and a test set (fake images generated by the generator from random noise after training is done).
- Divide the training set into K clusters using K-means clustering. These clusters are our K bins.
- Allocate the test samples to these K bins based on the Euclidean distance between each test sample and the K cluster centroids.
- For each bin, run a two-sample test comparing the proportions of training and test samples that fall into it, and compute the Z-score. If the resulting p-value is below the significance threshold (0.05 is used in the paper), mark the bin as statistically different.
- Count the number of statistically different bins and divide by K.
- The resulting value lies between 0 and 1.

A high number of statistically different bins, i.e. a value close to 1, means severe mode collapse and therefore a bad model. An NDB score close to 0 means little or no mode collapse.

The NDB evaluation method comes from the paper "On GANs and GMMs".

A well-implemented version of the NDB calculation can be found in this Colab notebook by Kevin Shen.
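To make the steps concrete, here is a minimal NumPy sketch of the NDB computation. It is neither the paper's reference code nor Kevin Shen's notebook: it assumes a plain k-means written from scratch, a pooled two-proportion z-test per bin, and a two-sided p-value threshold of 0.05; the names ndb_score, num_bins, and alpha are hypothetical. It computes all pairwise distances at once, so it is only meant for small toy data, not all 60,000 flattened MNIST images:

```python
import math
import numpy as np


def assign_bins(data, centroids):
    # Index of the nearest centroid (Euclidean distance) for every sample
    dists = ((data[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return dists.argmin(axis=1)


def kmeans(data, k, iters=50, seed=0):
    # Plain Lloyd's k-means on the training data; returns the K centroids (bins)
    rng = np.random.default_rng(seed)
    centroids = data[rng.choice(len(data), size=k, replace=False)].copy()
    for _ in range(iters):
        labels = assign_bins(data, centroids)
        for j in range(k):
            members = data[labels == j]
            if len(members) > 0:  # keep the old centroid if a cluster empties
                centroids[j] = members.mean(axis=0)
    return centroids


def ndb_score(train, generated, num_bins=10, alpha=0.05):
    # Returns (NDB / K, boolean flags marking the statistically different bins)
    centroids = kmeans(train, num_bins)
    n_train, n_gen = len(train), len(generated)
    p_train = np.bincount(assign_bins(train, centroids), minlength=num_bins) / n_train
    p_gen = np.bincount(assign_bins(generated, centroids), minlength=num_bins) / n_gen
    different = np.zeros(num_bins, dtype=bool)
    for j in range(num_bins):
        # Two-proportion z-test with a pooled proportion
        pooled = (p_train[j] * n_train + p_gen[j] * n_gen) / (n_train + n_gen)
        se = math.sqrt(pooled * (1 - pooled) * (1 / n_train + 1 / n_gen))
        if se == 0:
            continue  # no evidence of a difference for this bin
        z = (p_train[j] - p_gen[j]) / se
        p_value = math.erfc(abs(z) / math.sqrt(2))  # two-sided p-value
        different[j] = p_value < alpha
    return different.sum() / num_bins, different


# Toy usage: four well-separated 2-D blobs stand in for the "modes" of the data
rng = np.random.default_rng(1)
centers = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0], [10.0, 10.0]])
train = np.concatenate([c + rng.normal(size=(500, 2)) for c in centers])
good = np.concatenate([c + rng.normal(size=(500, 2)) for c in centers])
collapsed = centers[0] + rng.normal(size=(2000, 2))  # every sample from one mode
ndb_good, _ = ndb_score(train, good, num_bins=4)
ndb_collapsed, _ = ndb_score(train, collapsed, num_bins=4)
print(ndb_good, ndb_collapsed)
```

The "collapsed" generator, which only ever produces samples from one mode, leaves most bins under- or over-populated, so its score should be much closer to 1 than the score of the generator that covers all four modes.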

## Solving failure modes

Now that we understand how to identify problems in GAN training, let’s look at some solutions and rules of thumb for solving them. Some of these are basic hyperparameter adjustments; we’ll also discuss a couple of algorithms in case you want to go the extra mile to stabilize your GANs.

### Cost functions

Some papers argue that no loss function is clearly superior. I suggest starting with simpler loss functions, like the Binary Cross-Entropy we are using, and leveling up from there.

You aren’t obliged to use specific loss functions with specific GAN architectures. But a lot of research went into these papers, and much of it is still active, so it is good practice to try the loss functions in Fig. 8, which might help you prevent both mode collapse and convergence failure.

Experiment with different loss functions, and note that a loss function might appear to fail because of poorly tuned hyperparameters, such as an overly aggressive optimizer or too large a learning rate. We’ll talk more about these problems later.
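As an illustration (not necessarily the losses listed in Fig. 8), the NumPy sketch below writes out three commonly used GAN objectives side by side: the binary cross-entropy loss used in this article (with the non-saturating generator loss), the least-squares loss from LSGAN with targets 1 for real and 0 for fake, and the hinge loss. d_real and d_fake are hypothetical discriminator outputs on a batch of real and fake images: probabilities for BCE, raw unbounded scores for the other two (i.e. without the final sigmoid):

```python
import numpy as np

EPS = 1e-12  # avoids log(0)


def bce_losses(d_real, d_fake):
    # d_real, d_fake: discriminator probabilities in (0, 1), e.g. after a sigmoid
    loss_d = -np.mean(np.log(d_real + EPS)) - np.mean(np.log(1 - d_fake + EPS))
    loss_g = -np.mean(np.log(d_fake + EPS))  # non-saturating generator loss
    return loss_d, loss_g


def least_squares_losses(d_real, d_fake):
    # LSGAN-style loss with targets real=1, fake=0; inputs are raw scores
    loss_d = 0.5 * np.mean((d_real - 1) ** 2) + 0.5 * np.mean(d_fake ** 2)
    loss_g = 0.5 * np.mean((d_fake - 1) ** 2)
    return loss_d, loss_g


def hinge_losses(d_real, d_fake):
    # Hinge loss; inputs are raw scores
    loss_d = np.mean(np.maximum(0, 1 - d_real)) + np.mean(np.maximum(0, 1 + d_fake))
    loss_g = -np.mean(d_fake)
    return loss_d, loss_g


d_real = np.array([0.9, 0.8])
d_fake = np.array([0.2, 0.1])
print(bce_losses(d_real, d_fake))
```

In PyTorch you would compute the same expressions on tensors so gradients can flow; swapping the loss usually also means reconsidering the discriminator's final activation.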

### Latent space

The latent space is where the generator’s input (random noise) is sampled from. If you restrict the latent space, for example by reducing the latent dimension, the generator produces more outputs of the same type, as is evident from Fig. 9. You can also look at the corresponding loss graph in Fig. 6.

Notice the many similar 8s and 7s in Fig. 9: this is mode collapse.

When training a GAN, it is vital to provide a sufficiently large latent space so the generator can create a variety of features.

### Learning rate

One of the most common issues I have observed while training GANs is a high learning rate. It leads to either mode collapse or non-convergence. It’s really important that you keep the learning rate low, as low as 0.0002 or even lower.

The loss graph in Fig. 12 clearly shows that the discriminator is classifying all the images as real. That’s why the loss for fake images is high and the loss for real images is zero. The generator is now under the assumption that all the images it produces are fooling the discriminator. The problem is that, with such a high learning rate, the discriminator is not being trained even a little.

With a larger batch size you can afford a somewhat higher learning rate, but always try to stay on the safe side.
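Why is a GAN so sensitive to the learning rate? A classic toy example (a stand-in for intuition, not a GAN) is the two-player game min over x, max over y of f(x, y) = x·y, whose equilibrium is x = y = 0. With simultaneous gradient steps, one player descending and the other ascending, each step multiplies the distance from the equilibrium by exactly sqrt(1 + lr²), so the players spiral outwards, and the larger the learning rate, the faster they drift away. The function name below is hypothetical:

```python
import math


def simultaneous_gda(lr, steps, x=1.0, y=1.0):
    # Toy game min_x max_y f(x, y) = x * y, updated with simultaneous gradient steps
    for _ in range(steps):
        grad_x, grad_y = y, x                    # df/dx = y, df/dy = x
        x, y = x - lr * grad_x, y + lr * grad_y  # x descends, y ascends
    return math.hypot(x, y)                      # distance from the equilibrium (0, 0)


start = math.hypot(1.0, 1.0)
for lr in (0.0002, 0.01, 0.1):
    growth = simultaneous_gda(lr, steps=1000) / start
    print(f"lr={lr}: distance from equilibrium grew by a factor of {growth:.4f}")
```

In this toy game even a small learning rate does not converge; it only drifts much more slowly. Real GAN training is far more complex, but the same interplay between two simultaneously updated players is why large steps so easily push training away from an equilibrium.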

### Optimizer

An aggressive optimizer is bad news for training GANs. It can prevent the generator and discriminator losses from reaching an equilibrium, which leads to convergence failure.

In the Adam optimizer, the betas are hyperparameters used to compute the running averages of the gradient and its square. In the stable training run, we used 0.5 for beta1. Changing it to 0.9 (the default value) makes the optimizer more aggressive.

In Fig. 14, the discriminator is performing well. Since the generator loss keeps increasing, we can tell that the generator is producing such poor images that the discriminator easily classifies them as fake. The losses never reach an equilibrium.
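To see what beta1 changes, recall that Adam's update direction follows an exponential moving average of the gradient, m ← beta1 · m + (1 − beta1) · g. The plain-Python sketch below isolates just this first-moment average (it ignores Adam's second moment and bias correction, which rescale the update but do not change its sign) and feeds it a gradient that suddenly reverses, as can happen when the opposing network changes. The function name is hypothetical:

```python
def steps_until_sign_flip(beta1, steps_before_flip=20, max_steps=100):
    # First-moment (momentum) average used by Adam: m = beta1 * m + (1 - beta1) * g
    m = 0.0
    for _ in range(steps_before_flip):
        m = beta1 * m + (1 - beta1) * 1.0   # the gradient points one way...
    for step in range(1, max_steps + 1):
        m = beta1 * m + (1 - beta1) * -1.0  # ...then suddenly reverses
        if m < 0:
            return step  # steps until the update follows the new gradient direction
    return None


for beta1 in (0.5, 0.9):
    print(beta1, steps_until_sign_flip(beta1))
# 0.5 1
# 0.9 6
```

With beta1 = 0.5 the average follows the reversed gradient after a single step, while with beta1 = 0.9 it keeps pushing in the old direction for six steps. In an adversarial game, where the right direction keeps moving, that extra momentum makes overshooting more likely.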

### Feature matching

Feature matching proposes a new objective for the generator in which we do not directly use the discriminator’s output. Instead, the generator is trained so that its outputs match the statistics of real images on an intermediate layer of the discriminator.

For both real and fake images, feature vectors (f(x) in Fig. 15) are computed at an intermediate layer over a mini-batch, and the L2 distance between the means of these feature vectors is measured.

It makes more sense to match the generated data to the statistics of the real data. If the optimizer becomes too greedy in its search for the best data generation and never reaches convergence, feature matching can help.
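Written out, the generator minimizes $\lVert \mathbb{E}_{x \sim p_{\text{data}}} f(x) - \mathbb{E}_{z \sim p_z} f(G(z)) \rVert_2^2$, estimated on mini-batches. The NumPy sketch below only computes the value of this objective. As an assumption for illustration, f is a hypothetical stand-in (a fixed random projection followed by LeakyReLU(0.2)); in this article's model, one natural choice would be the flattened output of discriminator.conv, computed in PyTorch so that gradients reach the generator:

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(784, 32))  # hypothetical stand-in for an intermediate discriminator layer


def features(images):
    # f(x): flatten each 28x28 image, project it, and apply LeakyReLU(0.2)
    h = images.reshape(len(images), -1) @ W
    return np.where(h > 0, h, 0.2 * h)


def feature_matching_loss(real_images, fake_images):
    # Squared L2 distance between the mini-batch means of f(x) for real and fake images
    mean_real = features(real_images).mean(axis=0)
    mean_fake = features(fake_images).mean(axis=0)
    return float(np.sum((mean_real - mean_fake) ** 2))


real = rng.uniform(-1, 1, size=(64, 28, 28))
fake = rng.uniform(-1, 1, size=(64, 28, 28))
print(feature_matching_loss(real, real))  # 0.0: identical batches have identical statistics
print(feature_matching_loss(real, fake))
```

Because only batch means are compared, the generator is pushed to match the overall feature statistics of real data rather than to fool the discriminator on each individual sample.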

### Historical averaging

We keep a running average of the model parameters (θ) over the previous t steps and penalize the model by adding an L2 cost to the cost function that measures how far the current parameters are from this average:

$$\left\lVert \theta - \frac{1}{t}\sum_{i=1}^{t} \theta[i] \right\rVert^2$$

Here, θ[i] is the value of the parameters at the i-th past step.

When dealing with non-convex objective functions, historical averaging can help the model converge.
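A minimal NumPy sketch of this penalty, assuming the parameters are flattened into a single vector and that every past snapshot θ[i] is kept in a running sum (so memory does not grow with t). The class name HistoricalAverage and its weight argument are hypothetical; in PyTorch you would compute the same term on the model's parameters and add it to the loss before calling backward():

```python
import numpy as np


class HistoricalAverage:
    # Tracks (1/t) * sum of past parameter snapshots and the L2 penalty towards it
    def __init__(self, weight=1.0):
        self.weight = weight  # hypothetical strength of the penalty
        self.total = None
        self.count = 0

    def update(self, theta):
        # Record the current parameters as the newest snapshot theta[t]
        theta = np.asarray(theta, dtype=float)
        self.total = theta.copy() if self.total is None else self.total + theta
        self.count += 1

    def penalty(self, theta):
        # weight * || theta - mean(past thetas) ||^2; zero before any snapshot exists
        if self.count == 0:
            return 0.0
        average = self.total / self.count
        return self.weight * float(np.sum((np.asarray(theta, dtype=float) - average) ** 2))


history = HistoricalAverage()
for snapshot in ([1.0, 2.0], [3.0, 2.0]):
    history.update(snapshot)
print(history.penalty([2.0, 2.0]))  # 0.0: equal to the average of the two snapshots
print(history.penalty([4.0, 2.0]))  # 4.0: (4 - 2)^2 + (2 - 2)^2
```

The penalty pulls the parameters back towards where they have been on average, which damps the back-and-forth oscillations that simultaneous updates can cause.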

## Conclusion

- We now understand the importance of experiment tracking while training GANs.
- It is important to understand the loss graphs and carefully observe the intermediate data generated.
- Hyperparameters like learning rate, optimizer parameters, latent space, etc. can ruin your model if not tuned properly.
- As GAN models have multiplied in the past few years, more and more research has gone into stabilizing GAN training, and many more techniques exist that are useful for specific use cases.