MLOps Blog

Pix2pix: Key Model Architecture Decisions

13 min
Nilesh Barla
25th April, 2023

Generative Adversarial Networks, or GANs, are a class of neural networks that belong to unsupervised learning. They are used for the task of deep generative modeling. 

In deep generative modeling, a deep neural network learns a probability distribution over a given set of data points and generates similar data points. Since it is an unsupervised learning task, it doesn’t use any labels during the learning process. 

Since the first GAN was introduced in 2014, the deep learning community has been actively developing new GANs to advance the field of generative modeling. This article provides an overview of GANs and then focuses on Pix2Pix, one of the most widely used generative models for image-to-image translation.

What is a GAN?

The GAN was designed by Ian Goodfellow in 2014. Its main goal was to generate samples that were not blurry and had rich feature representations. Discriminative models were already doing well on this front, as they could separate different classes, but deep generative models were far less effective due to the difficulty of approximating many intractable probabilistic computations, a problem quite evident in autoencoders.  

Autoencoders and their variants are explicit likelihood models, which means they explicitly compute the probability density function over a given distribution. GANs and their variants are implicit likelihood models: they don’t compute the probability density function but rather learn the underlying distribution directly. 

GANs learn the underlying distribution by framing the whole problem as a binary classification task. In this approach, the problem is modeled with two networks: a generator and a discriminator. The job of the generator is to generate new samples, and the job of the discriminator is to classify whether a given sample is real or produced by the generator. 

The two models are trained together in a zero-sum game until the generator can produce samples that are similar to real samples. Or in other words, they are trained until the generator can fool the discriminator. 

Architecture of a vanilla GAN

Let’s briefly understand the architecture of GANs. From this section onward, most of the topics will be explained using code, so to begin with let’s define all the required dependencies. 

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

Generator

The generator is the component of a GAN that takes in noise, typically sampled from a Gaussian distribution, and yields samples similar to those in the original dataset. As GANs have evolved over the years, they have adopted CNNs, which are prominent in computer vision tasks. But for simplicity, we will define the generator with just linear layers using PyTorch. 

class Generator(nn.Module):
   def __init__(self, z_dim, img_dim):
       super().__init__()
       self.gen = nn.Sequential(
           nn.Linear(z_dim, 256),
           nn.LeakyReLU(0.01),
           nn.Linear(256, img_dim),
           nn.Tanh(),  # inputs are normalized to [-1, 1], so keep outputs in [-1, 1]
       )

   def forward(self, x):
       return self.gen(x)

Discriminator

The discriminator is simply a classifier that decides whether the data yielded by the generator is real or fake. It does this by learning the original distribution from the real data and then evaluating the generated samples against it. We will keep things simple and define the discriminator using linear layers as well. 

class Discriminator(nn.Module):
   def __init__(self, in_features):
       super().__init__()
       self.disc = nn.Sequential(
           nn.Linear(in_features, 128),
           nn.LeakyReLU(0.01),
           nn.Linear(128, 1),
           nn.Sigmoid(),
       )

   def forward(self, x):
       return self.disc(x)

The key difference between the generator and the discriminator is the last layer. The former yields an output of the same shape as the image, while the latter yields a single value between 0 and 1, the probability that the input is real. 
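A quick shape check makes this difference concrete; the dimensions below are illustrative and match the MNIST setup used later in this section:

z_dim, img_dim = 64, 28 * 28  # noise size and flattened image size

gen = Generator(z_dim, img_dim)
disc = Discriminator(img_dim)

noise = torch.randn(16, z_dim)  # a batch of 16 noise vectors
fake = gen(noise)               # shape [16, 784]: one flattened image per noise vector
score = disc(fake)              # shape [16, 1]: one real/fake probability per image

print(fake.shape, score.shape)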

Loss function and training

The loss function is one of the most important components in any deep learning algorithm. For instance, if we design a CNN to minimize the Euclidean distance between the ground truth and the predicted results, it will tend to produce blurry outputs. This is because the Euclidean distance is minimized by averaging all plausible outputs, which causes blurring.

See also

Understanding GAN Loss Functions

The above point is an important one to keep in mind. That said, the loss function we will use for the vanilla GAN is binary cross-entropy (BCELoss), because we are performing binary classification. 

criterion = nn.BCELoss()
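As a quick illustration (with made-up numbers), BCELoss takes the discriminator’s probabilities and a tensor of target labels and averages the binary cross-entropy over the batch:

preds = torch.tensor([0.9, 0.2, 0.7])  # hypothetical discriminator outputs
targets = torch.ones_like(preds)       # treat every sample as "real"
print(criterion(preds, targets))       # mean of -log(p) over the batch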

Now let’s define the hyperparameters, instantiate the two models, and set up the optimizers, the dataset, and the TensorBoard writers.

device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64
image_dim = 28 * 28 * 1  # 784
batch_size = 32
num_epochs = 100

disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

writer_fake = SummaryWriter("logs/fake")
writer_real = SummaryWriter("logs/real")
step = 0

Let’s walk through the training loop. Each iteration of GAN training consists of:

  1. Sampling noise from a Gaussian distribution and generating fake samples with the generator 
  2. Passing real data and the fake data produced by the generator through the discriminator 
  3. Updating the discriminator 
  4. Updating the generator

Here’s what the training loop looks like: 

for epoch in range(num_epochs):
   for batch_idx, (real, _) in enumerate(loader):
       real = real.view(-1, 784).to(device)
       batch_size = real.shape[0]
       ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
       noise = torch.randn(batch_size, z_dim).to(device)
       fake = gen(noise)
       disc_real = disc(real).view(-1)
       lossD_real = criterion(disc_real, torch.ones_like(disc_real))
       disc_fake = disc(fake).view(-1)
       lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
       lossD = (lossD_real + lossD_fake) / 2
       disc.zero_grad()
       lossD.backward(retain_graph=True)
       opt_disc.step()
        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
       # where the second option of maximizing doesn't suffer from
       # saturating gradients
       output = disc(fake).view(-1)
       lossG = criterion(output, torch.ones_like(output))
       gen.zero_grad()
       lossG.backward()
       opt_gen.step()
       if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )
           with torch.no_grad():
               fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
               data = real.reshape(-1, 1, 28, 28)
               img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
               img_grid_real = torchvision.utils.make_grid(data, normalize=True)
               writer_fake.add_image(
                   "Mnist Fake Images", img_grid_fake, global_step=step
               )
               writer_real.add_image(
                   "Mnist Real Images", img_grid_real, global_step=step
               )
               step += 1

Important points from the loop above:

  1. The loss function for the discriminator is calculated two times: one for real images and another for fake images.
    • For real images, the ground truth is set to ones using the torch.ones_like function, which returns a matrix of ones of the defined shape.  
    • For fake images, the ground truth is set to zeros using the torch.zeros_like function, which returns a matrix of zeros of the defined shape.  
  2. The loss function for the generator is calculated only once. If you observe carefully it is the same loss function that is used by the discriminator to calculate the loss for fake images. The only difference is that instead of using the torch.zeros_like function, torch.ones_like function is used. The interchanging of labels from 0 to 1 enables the generator to learn representations that will produce real images, therefore fooling the discriminator. 

Mathematically, we can define the whole process as:

min_G max_D V(D, G) = E_x[log D(x)] + E_z[log(1 - D(G(z)))]

where z is the noise, x is the real data, G is the generator, and D is the discriminator.

Application of GANs

GANs are widely used for:

  • Generating training samples: GANs are often used to generate training samples for specific tasks, such as classifying malignant and benign cancer cells, especially when there is too little data to train a classifier. 
  • AI Art or Generative Art: AI or generative art is another new domain where GANs are used extensively. Since the introduction of non-fungible tokens, artists all over the world have been creating art in unorthodox, i.e. digital and generative, ways. GAN-based tools such as DeepDaze, BigSleep, BigGAN, and VQGAN+CLIP are among the most commonly used by creators. 
AI Art or Generative Art
AI Art or Generative Art | Source: Author
  • Image-to-image translation: Image-to-image translation is also widely used by digital creators. The idea here is to translate a certain type of image into an image in the target domain, for example a daylight image into a night image, or a winter image into a summer image (see the image below). Pix2Pix, CycleGAN, and StyleGAN are a few of the most popular GANs for this task. 
Image-to-image translation
Image-to-image translation | Source
  • Text-to-image translation: Text-to-image translation is simply converting a given text or string into an image. This is a pretty hot domain right now, with a growing community. As mentioned previously, models such as DeepDaze, BigSleep, and DALL·E from OpenAI are quite popular for this.
Text-to-image translation
Text-to-image translation | Source 

Issues with GANs

Although GANs can produce images from a random Gaussian distribution that look similar to real images, the process is far from perfect most of the time. Here’s why:

  • Mode Collapse: Mode collapse refers to the issue where the generator learns to fool the discriminator by producing only a narrow subset of the data distribution. Because of mode collapse, the GAN is unable to learn a wide variety of modes and remains limited to a few. 
  • Diminished gradient: Diminished or vanishing gradients occur when the gradients flowing back through the network become so small that the updates to the weights are almost negligible. To overcome this issue, WGANs are recommended. 
  • Non-convergence: This occurs when the network is unable to converge to a global minimum, which results from unstable training. The issue can be tackled with spectral normalization. You can read about spectral normalization here.

Check also

Vanishing and Exploding Gradients in Neural Network Models: Debugging, Monitoring, and Fixing

Variations of GAN

Since the release of the first GAN, there have been many variants of GANs. Below are some of the most popular GANs:

  • CycleGAN
  • StyleGAN
  • PixelRNN
  • Text2image
  • DiscoGAN
  • IsGAN

See also

6 GAN Architectures You Really Should Know

Understanding GAN Loss Functions

This article focuses solely on the Pix2Pix GAN. In the following sections, we will look at some of its key components, such as the architecture and the loss function. 

What is the Pix2Pix GAN?

Pix2Pix GAN is a conditional GAN (cGAN) developed by Phillip Isola et al. Unlike a vanilla GAN, which uses only real data and noise to learn and generate images, a cGAN uses real data, noise, and labels (conditions) to generate images. 

In essence, the generator learns the mapping from the real data as well as the noise. 


Similarly, the discriminator also learns representation from labels as well as real data. 


This setting makes cGANs suitable for image-to-image translation tasks, where the generator is conditioned on an input image to generate a corresponding output image. In other words, the generator uses a conditional distribution (or data) as guidance or a blueprint to generate a target image (see the image below).  

Pix2Pix is a conditional GAN
Pix2Pix is a conditional GAN | Source: Author
Application of Pix2Pix
Application of Pix2Pix | Source

The idea behind Pix2Pix relies on the dataset provided for training. It is paired image-to-image translation, with training examples {x, y} that have a direct correspondence between them. 

Pix2Pix network architectures

Pix2Pix uses two important architectures, one for the generator and the other for the discriminator: the U-Net and the PatchGAN. Let’s explore both of them in more detail. 

U-Net generator 

As mentioned before, the generator architecture used in Pix2Pix is the U-Net. The U-Net was originally developed for biomedical image segmentation by Ronneberger et al. in 2015. 

U-Net Generator
U-Net generator | Source

U-Net consists of two major parts: 

  1. A contracting path made up of convolutional layers (left side), which downsamples the data while extracting information. 
  2. An expansive path made up of transpose convolutional layers (right side), which upsamples the information. 

Let’s say our downsampling path has three convolutional layers C_l(1,2,3); then we have to make sure that our upsampling path has three transpose convolutional layers C_u(1,2,3). This is because we want to connect the corresponding blocks of the same size using a skip connection. 

Skip connection
Skip connection | Source: Author

Downsampling

During downsampling, each convolutional block extracts spatial information and passes the information to the next convolutional block to extract more information until it reaches the middle part known as the bottleneck. Upsampling starts from the bottleneck. 

Upsampling

During upsampling, each transpose convolutional block expands information from the previous block while concatenating the information from the corresponding downsampling block. By concatenating information, the network can then learn to assemble a more precise output based on this information.

This architecture is able to localize, i.e. it can find the object of interest pixel by pixel. Furthermore, U-Net allows the network to propagate context information from lower-resolution to higher-resolution layers, which is what lets it generate high-resolution samples. 
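As a minimal illustration of a skip connection (not the full Pix2Pix generator, which we define later), a decoder block upsamples its input and concatenates it with the matching encoder feature map along the channel dimension:

up = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)

enc_feat = torch.randn(1, 128, 64, 64)    # feature map saved on the contracting path
bottleneck = torch.randn(1, 256, 32, 32)  # output of the previous decoder stage

decoded = up(bottleneck)                      # upsampled back to 64x64
skip = torch.cat([decoded, enc_feat], dim=1)  # concatenated along channels
print(skip.shape)                             # torch.Size([1, 256, 64, 64])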

Markovian discriminator (PatchGAN)

The discriminator uses the PatchGAN architecture. This architecture consists of a number of convolutional blocks. It takes an NxN patch of the image and tries to decide whether that patch is real or fake. N can be much smaller than the full image, and the discriminator still produces high-quality results. The discriminator is applied convolutionally across the whole image. Also, because the discriminator is smaller, i.e. it has fewer parameters than the generator, it is effectively faster. 

PatchGAN effectively models the image as a Markov random field, where each NxN patch is treated as independent. Therefore, PatchGAN can be understood as a form of texture/style loss.
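To make the “patch” idea concrete, here is a small sketch (separate from the implementation later in this article) of a PatchGAN-style stack of strided convolutions. Instead of a single score, it outputs a grid of scores, each covering one receptive-field patch of a 256x256 input:

patch_disc = nn.Sequential(
    nn.Conv2d(3, 64, 4, stride=2, padding=1),    # 256 -> 128
    nn.LeakyReLU(0.2),
    nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 128 -> 64
    nn.LeakyReLU(0.2),
    nn.Conv2d(128, 256, 4, stride=2, padding=1), # 64 -> 32
    nn.LeakyReLU(0.2),
    nn.Conv2d(256, 1, 4, stride=1, padding=1),   # 32 -> 31
)

with torch.no_grad():
    scores = patch_disc(torch.randn(1, 3, 256, 256))
print(scores.shape)  # torch.Size([1, 1, 31, 31]): one score per overlapping patch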

Loss function

The loss function of the conditional GAN is: 

L_cGAN(G, D) = E_{x,y}[log D(x, y)] + E_{x,z}[log(1 - D(x, G(x, z)))]

The equation above has two components: one for the discriminator and the other for the generator. Let’s understand both of them one by one. 

In any GAN, the discriminator is trained first in every iteration so that it can recognize both real and fake data and discriminate, or classify, between them. Essentially, 

D(x,y) = 1 i.e. real and, 

D(x,G(z)) = 0 i.e. fake. 

It is worth noting that G(z) produces fake samples, so the discriminator’s output for them should be close to zero. In theory, the discriminator should always classify G(z) as zero. The discriminator should therefore maintain the maximum distance between real and fake, i.e. 1 and 0, in every iteration. In other words, the discriminator should maximize the loss function. 

After the discriminator, the generator is trained. The generator, i.e. G(z), should learn to produce samples that are closer to the real samples. To learn the original distribution, it takes help from the discriminator: instead of D(x, G(z)) = 0, we set D(x, G(z)) = 1. 

With this change of labels, the generator now optimizes its parameters through the discriminator, using the “real” ground-truth label. This step ensures that the generator yields samples that are close to the real data, i.e. 1. 

The GAN loss is also mixed with an L1 loss so that the generator not only fools the discriminator but also produces images close to the ground truth. In essence, the objective has an additional L1 term for the generator: 

L_L1(G) = E_{x,y,z}[||y - G(x, z)||_1]

Therefore, the final loss function is:

G* = arg min_G max_D L_cGAN(G, D) + λ L_L1(G)

It is worth noting that the L1 loss preserves the low-frequency details in the image but cannot capture the high-frequency details, so on its own it would still produce blurry images. To tackle this problem, the PatchGAN discriminator is used. 
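In code, the combined generator objective typically looks like the sketch below, where disc_fake_patch stands for the discriminator’s patch predictions on a generated image and λ = 100 is the weight used later in this article:

bce = nn.BCELoss()
l1 = nn.L1Loss()
lambda_l1 = 100.0

def generator_loss(disc_fake_patch, fake_image, target_image):
    # Adversarial term: push the discriminator's patch scores towards "real" (1),
    # plus an L1 term pulling the generated image towards the ground truth.
    adv = bce(disc_fake_patch, torch.ones_like(disc_fake_patch))
    return adv + lambda_l1 * l1(fake_image, target_image)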

Optimization 

The optimization and training process is similar to that of a vanilla GAN, but training itself is a difficult process since the GAN objective is more concave-concave than convex-concave. This makes it difficult to find a saddle point, which is what makes training and optimizing GANs hard. 

As we saw previously, the generator is not trained directly but through the discriminator, which essentially limits the optimization of the generator. If the discriminator fails to capture the high-dimensional structure of the data, then it is very likely that the generator will fail to produce good samples. On the other hand, if we are able to train the discriminator well, then we can be fairly sure that the generator will be trained well too. 

In the early stages of training, G is untrained and too weak to produce good samples, which makes the discriminator very powerful. So instead of minimizing log(1 - D(G(z))), the generator is trained to maximize log D(G(z)). This creates some stability in the early stages of training. 

Other ways to tackle the instability are:

  1. Using spectral normalization in every layer of the model (see the sketch after this list)
  2. Using Wasserstein loss which calculates the average score for real or fake images.
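As a quick illustration of the first point, PyTorch ships torch.nn.utils.spectral_norm, which can wrap individual layers. The toy discriminator below (layer sizes are arbitrary and for illustration only) applies it to every convolution; spectral normalization constrains each layer’s largest singular value, which keeps the discriminator’s gradients better behaved:

from torch.nn.utils import spectral_norm

sn_disc = nn.Sequential(
    spectral_norm(nn.Conv2d(3, 64, 4, stride=2, padding=1)),
    nn.LeakyReLU(0.2),
    spectral_norm(nn.Conv2d(64, 128, 4, stride=2, padding=1)),
    nn.LeakyReLU(0.2),
    spectral_norm(nn.Conv2d(128, 1, 4, stride=1, padding=1)),
)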

To learn more refer to this article:

A Gentle Introduction to Generative Adversarial Network Loss Functions

Hands-on example with Pix2Pix

Let’s code Pix2Pix with PyTorch and build an intuitive understanding of how it works and of the various components behind it. 

Let’s start by downloading the data. The following code can be used to download the data.

!wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz
!tar -xvf facades.tar.gz

Data visualization

Once the data is downloaded, we can visualize it to understand what steps are needed to format it according to our requirements. 

We will import the following libraries for data visualization.

import matplotlib.pyplot as plt
import cv2
import os
import numpy as np

path = '/content/facades/train/'
plt.imshow(cv2.imread(f'{path}91.jpg'))
Data visualization
Source: Author

From the image above, we can see that each file contains two images attached together. If we check the shape of the image, we find that the width is 512, which means that the image can easily be separated into two halves. 

print('Shape of the image: ',cv2.imread(f'{path}91.jpg').shape)

>> Shape of the image:  (256, 512, 3)

To separate the images we will use the following commands:

#dividing the image by width
image = cv2.imread(f'{path}91.jpg')
w = image.shape[1]//2
image_real = image[:, :w, :]
image_cond = image[:, w:, :]
fig, axes = plt.subplots(1,2, figsize=(18,6))
axes[0].imshow(image_real, label='Real')
axes[1].imshow(image_cond, label='Condition')
plt.show()
Data visualization
Source: Author

The image on the left will be our ground truth while the image on the right will be our conditional image. We will refer to them as y and x respectively. 

Creating dataloader

A dataloader will allow us to format the data as required by PyTorch. This involves two steps: 

1. Formatting the data: reading the images from the source, splitting them into the real and conditional halves, and converting them to PyTorch tensors. 

from glob import glob
from torch.utils.data import Dataset

class data(Dataset):
   def __init__(self, path='/content/facades/train/'):
       self.filenames = glob(path+'*.jpg')


   def __len__(self):
       return len(self.filenames)

   def __getitem__(self, idx):
       filename = self.filenames[idx]

       image = cv2.imread(filename)
       image_width = image.shape[1]
       image_width = image_width // 2
       real = image[:, :image_width, :]
       condition = image[:, image_width:, :]

       real = transforms.functional.to_tensor(real)
       condition = transforms.functional.to_tensor(condition)

       return real, condition

2. Loading the data using PyTorch’s DataLoader to create batches before feeding them into the neural nets. 

train_dataset = data()
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

val_dataset = data(path='/content/facades/val/')
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=True)

Keep in mind that we will create a data loader for training and validation. 
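A quick check confirms that each batch holds four pairs of 3-channel 256x256 tensors:

real_batch, cond_batch = next(iter(train_loader))
print(real_batch.shape, cond_batch.shape)  # torch.Size([4, 3, 256, 256]) each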

Utils

In this section, we create the building blocks that will be used in the generator and discriminator: a convolutional block for downsampling and a transpose convolutional block for upsampling, referred to as cnn_block and tcnn_block respectively. 

def cnn_block(in_channels,out_channels,kernel_size,stride=1,padding=0, first_layer = False):

   if first_layer:
       return nn.Conv2d(in_channels,out_channels,kernel_size,stride=stride,padding=padding)
   else:
       return nn.Sequential(
           nn.Conv2d(in_channels,out_channels,kernel_size,stride=stride,padding=padding),
           nn.BatchNorm2d(out_channels,momentum=0.1,eps=1e-5),
           )

def tcnn_block(in_channels,out_channels,kernel_size,stride=1,padding=0,output_padding=0, first_layer = False):
   if first_layer:
       return nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,output_padding=output_padding)

   else:
       return nn.Sequential(
           nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,output_padding=output_padding),
           nn.BatchNorm2d(out_channels,momentum=0.1,eps=1e-5),
           )
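As a quick check of these helpers (with illustrative sizes), a cnn_block with kernel 4, stride 2, and padding 1 halves the spatial resolution, and a tcnn_block with the same settings doubles it:

x = torch.randn(1, 64, 128, 128)

down = cnn_block(64, 128, 4, stride=2, padding=1)
up = tcnn_block(128, 64, 4, stride=2, padding=1)

print(down(x).shape)      # torch.Size([1, 128, 64, 64]): halved
print(up(down(x)).shape)  # torch.Size([1, 64, 128, 128]): back to the original size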

Generator

Now, let’s define the generator using these two building blocks. 

import torch.nn.functional as F

class Generator(nn.Module):
 def __init__(self,instance_norm=False):#input : 256x256
   super(Generator,self).__init__()
   self.e1 = cnn_block(c_dim,gf_dim,4,2,1, first_layer = True)
   self.e2 = cnn_block(gf_dim,gf_dim*2,4,2,1,)
   self.e3 = cnn_block(gf_dim*2,gf_dim*4,4,2,1,)
   self.e4 = cnn_block(gf_dim*4,gf_dim*8,4,2,1,)
   self.e5 = cnn_block(gf_dim*8,gf_dim*8,4,2,1,)
   self.e6 = cnn_block(gf_dim*8,gf_dim*8,4,2,1,)
   self.e7 = cnn_block(gf_dim*8,gf_dim*8,4,2,1,)
   self.e8 = cnn_block(gf_dim*8,gf_dim*8,4,2,1, first_layer=True)

   self.d1 = tcnn_block(gf_dim*8,gf_dim*8,4,2,1)
   self.d2 = tcnn_block(gf_dim*8*2,gf_dim*8,4,2,1)
   self.d3 = tcnn_block(gf_dim*8*2,gf_dim*8,4,2,1)
   self.d4 = tcnn_block(gf_dim*8*2,gf_dim*8,4,2,1)
   self.d5 = tcnn_block(gf_dim*8*2,gf_dim*4,4,2,1)
   self.d6 = tcnn_block(gf_dim*4*2,gf_dim*2,4,2,1)
   self.d7 = tcnn_block(gf_dim*2*2,gf_dim*1,4,2,1)
   self.d8 = tcnn_block(gf_dim*1*2,c_dim,4,2,1, first_layer = True)#256x256
   self.tanh = nn.Tanh()

 def forward(self,x):
   e1 = self.e1(x)
   e2 = self.e2(F.leaky_relu(e1,0.2))
   e3 = self.e3(F.leaky_relu(e2,0.2))
   e4 = self.e4(F.leaky_relu(e3,0.2))
   e5 = self.e5(F.leaky_relu(e4,0.2))
   e6 = self.e6(F.leaky_relu(e5,0.2))
   e7 = self.e7(F.leaky_relu(e6,0.2))
   e8 = self.e8(F.leaky_relu(e7,0.2))
   d1 = torch.cat([F.dropout(self.d1(F.relu(e8)),0.5,training=True),e7],1)
   d2 = torch.cat([F.dropout(self.d2(F.relu(d1)),0.5,training=True),e6],1)
   d3 = torch.cat([F.dropout(self.d3(F.relu(d2)),0.5,training=True),e5],1)
   d4 = torch.cat([self.d4(F.relu(d3)),e4],1)
   d5 = torch.cat([self.d5(F.relu(d4)),e3],1)
   d6 = torch.cat([self.d6(F.relu(d5)),e2],1)
   d7 = torch.cat([self.d7(F.relu(d6)),e1],1)
   d8 = self.d8(F.relu(d7))

   return self.tanh(d8)

Discriminator

Let’s define the discriminator using the downsampling function. 

class Discriminator(nn.Module):
 def __init__(self,instance_norm=False):#input : 256x256
   super(Discriminator,self).__init__()
   self.conv1 = cnn_block(c_dim*2,df_dim,4,2,1, first_layer=True) # 128x128
   self.conv2 = cnn_block(df_dim,df_dim*2,4,2,1)# 64x64
   self.conv3 = cnn_block(df_dim*2,df_dim*4,4,2,1)# 32 x 32
   self.conv4 = cnn_block(df_dim*4,df_dim*8,4,1,1)# 31 x 31
   self.conv5 = cnn_block(df_dim*8,1,4,1,1, first_layer=True)# 30 x 30

   self.sigmoid = nn.Sigmoid()
 def forward(self, x, y):
   O = torch.cat([x,y],dim=1)
   O = F.leaky_relu(self.conv1(O),0.2)
   O = F.leaky_relu(self.conv2(O),0.2)
   O = F.leaky_relu(self.conv3(O),0.2)
   O = F.leaky_relu(self.conv4(O),0.2)
   O = self.conv5(O)

   return self.sigmoid(O)

Defining parameters

In this section, we will define the parameters. These parameters will help us in training the neural network. 

# Define parameters
batch_size = 4
workers = 2

epochs = 30

gf_dim = 64
df_dim = 64

L1_lambda = 100.0

in_w = in_h = 256
c_dim = 3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Initializing the models

Let’s initialize both models and enable CUDA for faster training. 

G = Generator().to(device)
D = Discriminator().to(device)

We will also define the optimizers and the loss function. 

G_optimizer = optim.Adam(G.parameters(), lr=2e-4,betas=(0.5,0.999))
D_optimizer = optim.Adam(D.parameters(), lr=2e-4,betas=(0.5,0.999))

bce_criterion = nn.BCELoss()
L1_criterion = nn.L1Loss()
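Before training, a quick forward pass with random tensors confirms the expected shapes: the generator maps a 256x256 conditional image to a 256x256 output, and the discriminator maps an (input, target) pair to a 30x30 patch map:

with torch.no_grad():
    x_dummy = torch.randn(1, c_dim, in_h, in_w).to(device)
    y_dummy = G(x_dummy)           # torch.Size([1, 3, 256, 256])
    patches = D(x_dummy, y_dummy)  # torch.Size([1, 1, 30, 30])
print(y_dummy.shape, patches.shape)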

Training

Once all the important functions are defined, we can set up the training loop. 
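The loop below also references a few objects that are not defined in the snippets above: a Neptune run for logging, a fixed validation pair for visualizing progress, and lists for tracking the losses. Here is a minimal sketch of setting them up; the project name is a placeholder, and Neptune setup is covered in more detail in the monitoring section further down:

import torchvision.utils as vutils
import neptune
from neptune.types import File

# Placeholder project name; see the monitoring section below for Neptune setup details.
run = neptune.init_run(project="your-workspace/pix2pix")

# Histories for the losses printed and plotted during training
G_losses, D_losses, G_GAN_losses, G_L1_losses, img_list = [], [], [], [], []

# A fixed validation pair so progress is visualized on the same input every time
fixed_y, fixed_x = next(iter(val_loader))
fixed_x = fixed_x.to(device)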

for ep in range(epochs):
 for i, data in enumerate(train_loader):

   y, x = data
   x = x.to(device)
   y = y.to(device)

   b_size = x.shape[0]

   real_class = torch.ones(b_size,1,30,30).to(device)
   fake_class = torch.zeros(b_size,1,30,30).to(device)

   #Train D
   D.zero_grad()

   real_patch = D(y,x)
   real_gan_loss=bce_criterion(real_patch,real_class)

   fake=G(x)

   fake_patch = D(fake.detach(),x)
   fake_gan_loss=bce_criterion(fake_patch,fake_class)

   D_loss = real_gan_loss + fake_gan_loss
   D_loss.backward()
   D_optimizer.step()


   #Train G
   G.zero_grad()
   fake_patch = D(fake,x)
   fake_gan_loss=bce_criterion(fake_patch,real_class)

   L1_loss = L1_criterion(fake,y)
   G_loss = fake_gan_loss + L1_lambda*L1_loss
   G_loss.backward()

   G_optimizer.step()

   #Neptune logging
   run["Gen Loss"].log(G_loss.item())
   run["Dis Loss"].log(D_loss.item())
   run['L1 Loss'].log(L1_loss.item())
   run['Gen GAN Loss'].log(fake_gan_loss.item())
   # Log PyTorch model weights
   torch.save(G.state_dict(), 'PIX2PIX.ckpt')
   run['model_checkpoints'].upload('PIX2PIX.ckpt')

   if (i+1)%5 == 0 :
     print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f},D(real): {:.2f}, D(fake):{:.2f},g_loss_gan:{:.4f},g_loss_L1:{:.4f}'
           .format(ep, epochs, i+1, len(train_loader), D_loss.item(), G_loss.item(),real_patch.mean(), fake_patch.mean(),fake_gan_loss.item(),L1_loss.item()))
     G_losses.append(G_loss.item())
     D_losses.append(D_loss.item())
     G_GAN_losses.append(fake_gan_loss.item())
     G_L1_losses.append(L1_loss.item())

     with torch.no_grad():
       G.eval()
       fake = G(fixed_x).detach().cpu()
       G.train()
     figs=plt.figure(figsize=(10,10))
     plt.subplot(1,3,1)
     plt.axis("off")
     plt.title("conditional image (x)")
     plt.imshow(np.transpose(vutils.make_grid(fixed_x, nrow=1,padding=5, normalize=True).cpu(),(1,2,0)))

     plt.subplot(1,3,2)
     plt.axis("off")
     plt.title("fake image")
     plt.imshow(np.transpose(vutils.make_grid(fake, nrow=1,padding=5, normalize=True).cpu(),(1,2,0)))

     plt.subplot(1,3,3)
     plt.axis("off")
     plt.title("ground truth (y)")
     plt.imshow(np.transpose(vutils.make_grid(fixed_y, nrow=1,padding=5, normalize=True).cpu(),(1,2,0)))

     plt.savefig(os.path.join('./','pix2pix'+"-"+str(ep) +".png"))

     run['Results'].log(File(f'pix2pix-{str(ep)}.png'))
     plt.close()
     img_list.append(figs)

Monitoring our model

Training the model is not the last step. You need to monitor the training and track it to analyze the performance and implement changes if necessary. Given how taxing it can get to monitor the performance of a GAN with too many losses, plots, and metrics to deal with, we will use Neptune at this step.

Neptune allows the user to:

  1. Monitor the live performance of the model
  2. Monitor the performance of the hardware
  3. Store and compare metadata across different runs (like metrics, parameters, performance, data, etc.)
  4. Share the work with others

To get started, just follow these steps:

1. Install the Neptune client library using pip install neptune or conda install -c conda-forge neptune on your local system.

!pip install neptune

2. Create an account and log into neptune.ai.

3. Once you’re logged in, create a New Project.

4. Now, you can log different metadata to Neptune. Read more about it here.

For this project, we will log our parameters to the Neptune dashboard. To log the parameters or any other information to the dashboard, first create a dictionary.

#Logging parameter in Neptune
PARAMS = {'Epoch': epochs,
         'Batch Size': batch_size,
         'Input Channels': c_dim,

         'Workers': workers,
         'Optimizer': 'Adam',
         'Learning Rate': 2e-4,
         'Metrics': ['Binary Cross Entropy', 'L1 Loss'],
         'Activation': ['Leaky Relu', 'Tanh', 'Sigmoid' ],
         'Device': device}

Once the dictionary is created we will log them using the following command:

run['parameters'] = PARAMS

Keep in mind that the losses, generated images, and the model’s weights are all logged into the Neptune dashboard through the `run` object. 

For instance, in the training loop above you will find the following commands:

#Neptune logging
   run["Gen Loss"].log(G_loss.item())
   run["Dis Loss"].log(D_loss.item())
   run['L1 Loss'].log(L1_loss.item())
   run['Gen GAN Loss'].log(fake_gan_loss.item())

These are essentially used to log data into the Neptune dashboard. 

Learn more

Explore the project in Neptune

Once training is initialized, all the logged information is automatically sent to the dashboard. Neptune fetches live information from the training run, which allows live monitoring of the whole process. 

Below are the screenshots of the monitoring process. 

Monitoring the performance of the model
Monitoring the performance of the model | Source
Monitoring the performance of the hardware
Monitoring the performance of the hardware | Source

You can also access all metadata and see generated samples.

Access to all metadata
Access to all metadata | Source
Access to the generated samples
Access to the generated samples | Source

Finally, you can compare the metadata from different runs. This is useful, for example, when you want to see whether your model performs better than a previous one after tuning some parameters.

Pix2Pix compare runs
Comparing metadata from different runs | Source

Key takeaways

  • Pix2Pix is a conditional GAN that uses images and labels to generate images. 
  • It uses two architectures:
    • U-Net for generator
    • PatchGAN for discriminator
  • PatchGAN classifies smaller NxN patches of the generated image as real or fake instead of classifying the entire image at once. 
  • Pix2Pix has an additional loss specifically for the generator so that it can generate images closer to the ground truth. 
  • Pix2Pix is a pairwise image translation algorithm. 

Other GANs that you can explore are:

  1. CycleGAN: It is similar to Pix2Pix since most of the approach is the same except the data part. Instead of pair-image-translation it is unpaired-image-translation. Learning and exploring CycleGAN will be much easier since it was developed by the same authors.
  2. If you are interested in text-to-image translation then you should explore:
    • DeepDaze
    • BigSleep
    • DALL·E
  3. Other interesting GANs projects you may want to try out:
    • StyleGAN
    • AnimeGAN
    • BigGAN
    • Age-cGAN
    • StarGAN

References

  1. Image-to-Image Translation with Conditional Adversarial Networks
  2. Deep Learning Book: Ian Goodfellow 
  3. Generative Adversarial Networks: Goodfellow et al.
  4. Generative Adversarial Networks and Some of GAN Applications – Everything You Need to Know
  5. A Gentle Introduction to Generative Adversarial Network Loss Functions.