Neptune Blog

Pix2pix: Key Model Architecture Decisions

10 min
22nd April, 2025

Generative Adversarial Networks (GANs) are a class of neural networks that belong to the family of unsupervised learning models. They are used for deep generative modeling.

In deep generative modeling, deep neural networks learn a probability distribution over a given set of data points and use it to generate similar ones. Since this is an unsupervised learning task, no labels are used during training.

Since the original GAN was introduced in 2014, the deep learning community has been actively developing new variants to advance the field of generative modeling. This article provides an overview of GANs and, in particular, Pix2Pix, one of the most widely used generative models.

What is GAN?

GAN was designed by Ian Goodfellow and his colleagues in 2014. Its main goal was to generate samples that were not blurry and carried rich feature representations. Discriminative models were already doing well on this front, as they could reliably classify between different classes. Deep generative models, on the other hand, were far less effective due to the difficulty of approximating many intractable probabilistic computations, a problem that is quite evident in autoencoders.

Autoencoders and their variants are explicit likelihood models, meaning they explicitly compute the probability density function over a given distribution. GANs and their variants are implicit likelihood models, which means they don’t compute the probability density function but rather learn the underlying distribution. 

GANs learn the underlying distribution by framing the whole problem as a binary classification task. In this approach, two models are involved: a generator and a discriminator. The generator's job is to produce new samples, and the discriminator's job is to classify, or discriminate, whether a given sample is real or produced by the generator.

The two models are trained together in a zero-sum game until the generator can produce samples that are similar to the real samples. In other words, they are trained until the generator can fool the discriminator. 

Architecture of a vanilla GAN

Let's briefly go over the architecture of GANs. From this section onward, most of the topics will be explained using code, so let's start by installing and importing all the required dependencies:

pip install torch torchvision matplotlib opencv-python numpy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

Generator

The generator is the component of a GAN that takes in noise, typically sampled from a Gaussian distribution, and yields samples similar to those in the original dataset. As GANs have evolved over the years, they have adopted CNNs, which are prominent in computer vision tasks. For simplicity, though, we will define the generator with just linear layers using PyTorch.

class Generator(nn.Module):
   def __init__(self, z_dim, img_dim):
       super().__init__()
       self.gen = nn.Sequential(
           nn.Linear(z_dim, 256),
           nn.LeakyReLU(0.01),
           nn.Linear(256, img_dim),
           nn.Tanh(),  # squash the outputs to [-1, 1], matching the normalized real images
       )

   def forward(self, x):
       return self.gen(x)

Discriminator

The discriminator is simply a classifier that decides whether a sample is real or produced by the generator. It does this by learning from the real data distribution and comparing candidate samples against it. We will keep things simple and define the discriminator using linear layers as well.

class Discriminator(nn.Module):
   def __init__(self, in_features):
       super().__init__()
       self.disc = nn.Sequential(
           nn.Linear(in_features, 128),
           nn.LeakyReLU(0.01),
           nn.Linear(128, 1),
           nn.Sigmoid(),
       )

   def forward(self, x):
       return self.disc(x)

The key difference between the generator and the discriminator lies in the last layer. The former outputs a tensor of the same shape as the image, while the latter outputs a single probability between 0 and 1 indicating whether the input is real or fake.
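To make this difference concrete, here is a quick shape check (a minimal sketch, assuming z_dim = 64 and flattened 28×28 MNIST images, matching the hyperparameters used below):

# Quick shape check (assumes z_dim = 64 and flattened 28x28 images)
z_dim, img_dim = 64, 28 * 28
gen = Generator(z_dim, img_dim)
disc = Discriminator(img_dim)

noise = torch.randn(16, z_dim)   # a batch of 16 noise vectors
fake = gen(noise)                # torch.Size([16, 784]) -- same size as a flattened image
score = disc(fake)               # torch.Size([16, 1]) -- one probability per sample
print(fake.shape, score.shape)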

Loss function and training

The loss function is one of the most important components of any deep learning algorithm. For instance, if we train a CNN to minimize the Euclidean distance between the ground truth and the predicted results, it will tend to produce blurry outputs. This is because Euclidean distance is minimized by averaging over all plausible outputs, which causes blurring.

This is an important point to keep in mind. That said, the loss function we will use for the vanilla GAN is binary cross-entropy (BCELoss), since we are performing binary classification.

criterion = nn.BCELoss()

Now let’s define the optimization method and other related parameters:

# Define hyperparameters
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64
image_dim = 28 * 28 * 1  # 784
batch_size = 32
num_epochs = 100

# Initialize Generator and Discriminator
gen = Generator(z_dim, image_dim).to(device)
disc = Discriminator(image_dim).to(device)

# Fixed noise vector used to generate consistent samples for visualization
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

# Set up optimizers
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)

# Prepare dataset and dataloader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Set up TensorBoard writers
from torch.utils.tensorboard import SummaryWriter
writer_fake = SummaryWriter("logs/fake")
writer_real = SummaryWriter("logs/real")
step = 0

Let's walk through the training loop. Each iteration of a GAN's training loop consists of:

  1. Sampling noise from a Gaussian distribution and generating fake samples with the generator
  2. Passing both real data and the generator's fake data through the discriminator
  3. Updating the discriminator
  4. Updating the generator

Here’s what the training loop looks like:

for epoch in range(num_epochs):
    # Loop over each batch of data
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, 784).to(device)
        batch_size = real.shape[0]

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        # Generate random noise as input to the generator
        noise = torch.randn(batch_size, z_dim).to(device)
        # Produce fake images from the generator using the random noise
        fake = gen(noise)
        # Get discriminator's predictions on real images
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # Get discriminator's predictions on fake images
        disc_fake = disc(fake).view(-1)
        # Calculate discriminator's loss on fake images
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        # Combine real and fake loss to get total discriminator loss
        lossD = (lossD_real + lossD_fake) / 2
        # Backpropagation and optimization step for discriminator
        disc.zero_grad()
        lossD.backward(retain_graph=True)
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        # where the second option of maximizing doesn't suffer from
        # saturating gradients
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        # Backpropagation and optimization step for generator
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )
            with torch.no_grad():
                # Use the fixed noise defined earlier to generate consistent samples
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)
                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1

There are some important considerations from the loop above:

  1. The loss function for the discriminator is calculated twice: once for real images and once for fake images.
    • For real images, the targets are created with the torch.ones_like function, which returns a tensor of ones with the given shape.  
    • For fake images, the targets are created with the torch.zeros_like function, which returns a tensor of zeros with the given shape.  
  2. The loss function for the generator is calculated only once. If you observe carefully, it is the same loss the discriminator uses for fake images; the only difference is that torch.ones_like is used instead of torch.zeros_like. Flipping the labels from 0 to 1 pushes the generator to learn representations that produce realistic images, therefore fooling the discriminator.

Mathematically, we can define the whole process as:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

This equation is the objective function of a GAN. The generator G aims to produce realistic data to fool the discriminator D, while the discriminator tries to distinguish between real data x and generated data G(z), where z is random noise. The discriminator D maximizes its ability to classify real vs. generated samples, while the generator G minimizes the discriminator's success, creating a minimax game. The value function V(D, G) is thus minimized over G and maximized over D.

Application of GANs

GANs are widely used for:

  • Generating training samples: GANs are often used to generate samples for specific tasks like the classification of malignant and benign cancer cells, especially where the data is scarce to train a classifier. 
  • AI Art or Generative Art: AI or generative art is another new domain where GANs are extensively used. Since the introduction of non-fungible tokens, artists all over the world have been creating art in unorthodox fashions, i.e., digital and generative. Models and GANs like DeepDaze, BigSleep, BigGAN, CLIP, VQGAN, etc. are among the most commonly used by creators. 
AI Art or Generative Art: This figure has been generated with AI | Source: Author
  • Image-to-image translation: The idea here is to translate a certain type of image into an image in the target domain. For example, a daylight image into a night image, or a winter image into a summer image (see the image below). GANs like Pix2Pix, CycleGAN, and StyleGAN are a few of the most popular ones used by digital creators.
Image-to-image translation example. To the left, the original image (a car on a road during winter). To the right, the same image in the summer, generated by AI | Source
  • Text-to-image translation: Text-to-image translation is simply converting a text prompt, or a given string, into an image. This is a very popular domain right now with a growing community. As mentioned previously, models such as DeepDaze, BigSleep, and DALL·E from OpenAI are the most common tools for this.
Text-to-image translation. The prompt “an armchair in the shape of an avocado” turns into images of the desired chair in different angles. | Source

Issues with GANs

Although GANs can turn random Gaussian noise into images that resemble real ones, the process is far from perfect most of the time. Here's why:

  • Mode collapse: This refers to the issue where the generator learns to fool the discriminator by producing only a narrow subset of the data distribution. Because of mode collapse, the GAN cannot capture the full variety of the data and remains limited to a few modes. 
  • Diminished gradient: Diminished or vanishing gradients occur when the gradients flowing back through the network become so small that the updates to the weights are almost negligible. To overcome this issue, Wasserstein GANs (WGANs for short) are recommended. 
  • Non-convergence: This occurs when the network is unable to settle into a stable equilibrium. It results from unstable training and can be tackled with spectral normalization.

Variations of GAN

Since the release of the first GAN, there have been many variants of GANs. Below are some of the most popular GANs:

  • CycleGAN
  • StyleGAN
  • PixelRNN
  • Text2image
  • DiscoGAN
  • IsGAN

This article focuses solely on the Pix2Pix GAN. In the following sections, we will look at some of its key components, such as the architecture and the loss function.

What is the Pix2Pix GAN?

Pix2Pix GAN is a conditional GAN (cGAN) developed by Phillip Isola et al. Unlike a vanilla GAN, which uses only noise and real data to learn to generate images, a cGAN uses real data, noise, and labels (conditions) to generate images.

In essence, the generator learns its mapping from the conditioning data as well as the noise.

The generator G combines the conditioning input x and the random noise z to output y, the generated (fake) data.

Similarly, the discriminator not only learns from the “real data” example it has seen, but also from the labels that help it understand what is real and what is fake. 

The discriminator, then, uses two sources of information to improve its ability to tell real from fake: x (the conditioning data) and y (the image it is judging, either real or generated).

This setting makes the cGAN well suited for image-to-image translation tasks, where the generator is conditioned on an input image to generate the corresponding output image. In other words, the generator uses the conditioning data as a guide or blueprint to generate the target image (see the image below).

The model generates realistic building facades (right column) based on input segmentation maps (left column), with comparisons to the actual ground truth images (center column) | Source: Author
Applications of Pix2Pix, a type of conditional GANs | Source

The core idea of Pix2Pix relies on the dataset provided for training: it is a paired image-to-image translation method, with training examples {x, y} that have a direct correspondence between them.

Pix2Pix network architectures

Pix2Pix has two important architectures, one for the generator and one for the discriminator: U-Net and PatchGAN. Let's explore both of them in more detail.

U-Net generator 

As mentioned before, the generator architecture used in Pix2Pix is called U-Net. U-Net was originally developed for biomedical image segmentation by Ronneberger et al. in 2015.

U-Net generator:  A symmetric encoder-decoder structure with down-sampling through max pooling (red arrows) and up-sampling via transposed convolutions (green arrows). Skip connections (gray arrows) connect layers of matching spatial dimensions in the encoder and decoder, preserving spatial information for segmentation in the output map. | Source

U-Net consists of two major parts: 

  1. A contracting path made up of convolutional layers (left side), which downsamples the data while extracting information. 
  2. An expansive path made up of transposed convolution layers (right side), which upsamples the information. 

Let's say our downsampling path has three convolutional layers C_l(1,2,3); then we have to make sure that our upsampling path has three transposed convolutional layers C_u(1,2,3). This is because we want to connect the corresponding blocks of the same size using skip connections.

Skip connection architecture: This diagram illustrates the use of skip layers between encoder (C_l1, C_l2, C_l3) and decoder (C_u1, C_u2, C_u3) blocks, with a bottleneck in the center to keep the feature dimensions at each stage. This retains the spatial details across the network | Source: Author

Downsampling

During downsampling, each convolutional block extracts spatial information and passes the information to the next convolutional block to extract more information until it reaches the middle part known as the bottleneck. Upsampling starts from the bottleneck. 

Upsampling

During upsampling, each transposed convolutional block expands the information from the previous block and concatenates it with the information from the corresponding downsampling block. Thanks to this concatenation, the network can learn to assemble a more precise output.
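To see what a single skip connection does, here is a minimal sketch (the shapes are illustrative, not the exact Pix2Pix layer sizes):

# Illustrative only: an encoder feature map and the matching decoder feature map
enc_feat = torch.randn(1, 64, 128, 128)   # saved during downsampling
dec_feat = torch.randn(1, 64, 128, 128)   # produced during upsampling at the same resolution

# The skip connection concatenates them along the channel dimension, so the next
# decoder block receives 128 channels instead of 64
merged = torch.cat([dec_feat, enc_feat], dim=1)
print(merged.shape)   # torch.Size([1, 128, 128, 128])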

This architecture can localize, i.e. it can find the object of interest pixel by pixel. Furthermore, U-Net also allows the network to propagate context information from lower resolution to higher resolution layers. This allows the network to generate high-resolution samples. 

Markovian discriminator (PatchGAN)

The discriminator uses the PatchGAN architecture, which consists of several convolutional blocks. It takes an NxN patch of the image and tries to decide whether it is real or fake. N can be of any size: it can be much smaller than the full image and still produce high-quality results. The discriminator is applied convolutionally across the whole image. Also, because the discriminator is smaller, i.e., it has fewer parameters than the generator, it is faster.

PatchGAN effectively models the image as a Markov random field, treating each NxN patch as independent. Therefore, PatchGAN can be understood as a form of texture/style loss.
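To make the "patch" idea concrete, the sketch below stacks a few strided convolutions and prints the output grid; each cell in the final map scores one patch of the input. The layer sizes are illustrative and do not match the exact Pix2Pix discriminator defined later in this article:

# Illustrative PatchGAN-style discriminator: the output is a grid of patch scores,
# not a single scalar
patch_disc = nn.Sequential(
    nn.Conv2d(6, 64, 4, stride=2, padding=1),    # condition + image stacked -> 6 channels
    nn.LeakyReLU(0.2),
    nn.Conv2d(64, 128, 4, stride=2, padding=1),
    nn.LeakyReLU(0.2),
    nn.Conv2d(128, 1, 4, stride=1, padding=1),
    nn.Sigmoid(),
)

pair = torch.randn(1, 6, 256, 256)               # a (condition, image) pair, stacked channel-wise
print(patch_disc(pair).shape)                    # torch.Size([1, 1, 63, 63]) -> one score per patch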

Loss function

The loss function is: 

$$\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}[\log D(x, y)] + \mathbb{E}_{x,z}[\log(1 - D(x, G(x, z)))]$$

The equation above has two components: one for the discriminator and the other for the generator. Let’s understand both of them one by one. 

In any GAN, the discriminator is trained first in every iteration so that it can recognize both real and fake data. Essentially,

D(x, y) = 1, i.e., real, and

D(x, G(x, z)) = 0, i.e., fake.

Since G(x, z) produces fake samples, the discriminator's output on them should be close to zero; in theory, the discriminator should always classify G(x, z) as zero. The discriminator therefore tries to keep the maximum distance between real and fake, i.e., between 1 and 0, in every iteration. In other words, the discriminator maximizes the loss function above.

After the discriminator, the generator is trained. The generator G(x, z) should learn to produce samples that are closer to the real ones. To learn the original distribution, it takes help from the discriminator: instead of D(x, G(x, z)) = 0, we set the target to D(x, G(x, z)) = 1.

With this change of labels, the generator now updates its parameters through the discriminator, using "real" as the ground-truth label. This step ensures that the generator yields samples that are close to the real data, i.e., labeled 1.

The loss function is also mixed with an L1 loss so that the generator not only fools the discriminator but also produces images close to the ground truth. In essence, the objective has an additional L1 term for the generator.

$$\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}[\lVert y - G(x, z) \rVert_1]$$

Therefore, the final loss function is:

$$G^* = \arg\min_G \max_D \mathcal{L}_{cGAN}(G, D) + \lambda \, \mathcal{L}_{L1}(G)$$

It is worth noting that the L1 loss preserves the low-frequency content of the image but cannot capture high-frequency detail, so on its own it still produces blurry images. To tackle this problem, the PatchGAN discriminator is used.
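Putting these pieces together in code, here is a minimal sketch of the combined generator objective, assuming a discriminator that outputs patch scores, BCE and L1 criteria, and λ = 100 (the value used later in this article); the names are illustrative:

# Minimal sketch of the combined generator objective (assumed names and lambda value)
bce = nn.BCELoss()
l1 = nn.L1Loss()
l1_lambda = 100.0

def generator_objective(disc_on_fake, fake, target):
    # Adversarial term: the generator wants the discriminator to say "real" (ones) on its fakes
    adv_loss = bce(disc_on_fake, torch.ones_like(disc_on_fake))
    # Reconstruction term: stay close to the ground-truth image
    recon_loss = l1(fake, target)
    return adv_loss + l1_lambda * recon_loss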

Optimization 

The optimization and training process is similar to that of the vanilla GAN. However, training itself is difficult, since the GAN objective tends to be concave-concave rather than convex-concave. Because of this, finding a saddle point is hard, and this is what makes training and optimizing GANs tricky.

As we saw previously, the generator is not trained directly but through the discriminator, which effectively limits how well the generator can be optimized. If the discriminator fails to capture the high-dimensional structure of the data, the generator will almost certainly fail to produce good samples. On the other hand, if we can train the discriminator well, we can be reasonably confident that the generator will be trained well too.

In the early stages of training, G is untrained and too weak to produce good samples, which makes the discriminator very strong. So instead of minimizing log(1 − D(G(z))), the generator is trained to maximize log D(G(z)). This brings some stability to the early stages of training.
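In code, the two formulations differ only in what we ask of D(G(z)); here is a minimal sketch, reusing criterion, gen, disc, z_dim, batch_size, and device from the vanilla GAN example above:

# disc_fake = D(G(z)) for a batch of generated samples, values in (0, 1)
disc_fake = disc(gen(torch.randn(batch_size, z_dim).to(device))).view(-1)

# Saturating form: minimize log(1 - D(G(z))); gradients are weak early on,
# when the discriminator confidently rejects the fakes
loss_saturating = torch.log(1.0 - disc_fake + 1e-8).mean()

# Non-saturating form: maximize log D(G(z)), i.e. minimize BCE against "real" labels
# (this is what the training loop earlier in this article uses)
loss_non_saturating = criterion(disc_fake, torch.ones_like(disc_fake))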

Other ways to tackle the instability are:

  1. Using spectral normalization in every layer of the model (see the sketch after this list)
  2. Using the Wasserstein loss, which works with the average critic score for real and fake images instead of a classification probability.
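For the first point, PyTorch provides a ready-made wrapper; a minimal sketch (the layer sizes are illustrative):

from torch.nn.utils import spectral_norm

# Wrapping a layer with spectral_norm constrains its largest singular value,
# which helps keep the discriminator's gradients well behaved
sn_conv = spectral_norm(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1))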

Hands-on example with Pix2Pix

Let's implement Pix2Pix with PyTorch and build an intuitive understanding of how the algorithm works and of the various components behind it.

Let's start by downloading the data with the following commands:

!wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz
!tar -xvf facades.tar.gz

Data visualization

Once the data is downloaded, we can visualize it to understand what steps are needed to format it for training.

We will import the following libraries for data visualization.

import matplotlib.pyplot as plt
import cv2
import os
import numpy as np

path = "facades/train/"
plt.imshow(cv2.imread(f"{path}91.jpg"))
Resulting output from the previous code | Source: Author

From the image above, we can see that the data consists of two images attached side by side. If we then check the shape of the image, we find that the width is 512, which means the image can easily be split into two.

print('Shape of the image: ',cv2.imread(f'{path}91.jpg').shape)

>> Shape of the image:  (256, 512, 3)

To separate the images we will use the following commands:

# Dividing the image by width
image = cv2.imread(f'{path}91.jpg')
w = image.shape[1]//2
image_real = image[:, :w, :]
image_cond = image[:, w:, :]
fig, axes = plt.subplots(1,2, figsize=(18,6))
axes[0].imshow(image_real, label='Real')
axes[1].imshow(image_cond, label='Condition')
plt.show()

Output:

Resulting output | Source: Author

The image on the left will be our ground truth while the image on the right will be our conditional image. We will refer to them as y and x respectively (from left to right). 

Creating dataloader

A dataloader lets us feed the data to the model in the format PyTorch expects. This involves two steps: 

1. Formatting the data, i.e., reading the images from the source, splitting each into its real and conditional halves, and converting them to PyTorch tensors. 

from glob import glob
from torch.utils.data import Dataset


class Data(Dataset):
   def __init__(self, path="facades/train/"):
       self.filenames = glob(path + "*.jpg")

   def __len__(self):
       return len(self.filenames)

   def __getitem__(self, idx):
       filename = self.filenames[idx]

       image = cv2.imread(filename)
       image_width = image.shape[1]
       image_width = image_width // 2
       real = image[:, :image_width, :]
       condition = image[:, image_width:, :]

       real = transforms.functional.to_tensor(real)
       condition = transforms.functional.to_tensor(condition)

       return real, condition

2. Loading the data by using Pytorch’s DataLoader function to create batches before feeding them into the neural nets. 

train_dataset = Data()
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

val_dataset = Data(path="facades/val/")
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=True)

Keep in mind that we will create a data loader for training and validation. 
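As a quick sanity check, we can pull one batch and confirm the shapes (with a batch size of 4 and 256×256 RGB halves, both tensors should be 4×3×256×256):

# Optional sanity check: one batch from the training loader
real_batch, cond_batch = next(iter(train_loader))
print(real_batch.shape, cond_batch.shape)   # torch.Size([4, 3, 256, 256]) for both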

Utils

In this section, we create the components used to build the generator and the discriminator: a convolutional block for downsampling and a transposed convolution block for upsampling, referred to as cnn_block and tcnn_block respectively.

def cnn_block(
   in_channels, out_channels, kernel_size, stride = 1, padding = 0, first_layer = False
):

   if first_layer:
       return nn.Conv2d(
           in_channels, out_channels, kernel_size, stride = stride, padding = padding
       )
   else:
       return nn.Sequential(
           nn.Conv2d(
               in_channels, out_channels, kernel_size, stride = stride, padding = padding
           ),
           nn.BatchNorm2d(out_channels, momentum = 0.1, eps = 1e-5),
       )


def tcnn_block(
   in_channels,
   out_channels,
   kernel_size,
   stride = 1,
   padding = 0,
   output_padding = 0,
   first_layer = False,
):
   if first_layer:
       return nn.ConvTranspose2d(
           in_channels,
           out_channels,
           kernel_size,
           stride = stride,
           padding = padding,
           output_padding = output_padding,
       )

   else:
       return nn.Sequential(
           nn.ConvTranspose2d(
               in_channels,
               out_channels,
               kernel_size,
               stride = stride,
               padding = padding,
               output_padding = output_padding,
           ),
           nn.BatchNorm2d(out_channels, momentum = 0.1, eps = 1e-5),
       )

Defining parameters

In this section, we will define the parameters. These parameters will help us in training the neural network. 

# Define parameters
batch_size = 4
workers = 2

epochs = 30

gf_dim = 64
df_dim = 64

L1_lambda = 100.0

in_w = in_h = 256
c_dim = 3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Generator

Now, let’s define the generator. We will use the two components to define the same. 

class Generator(nn.Module):
 def __init__(self,instance_norm=False): #input : 256x256
   super(Generator,self).__init__()
   self.e1 = cnn_block(c_dim,gf_dim,4,2,1, first_layer = True)
   self.e2 = cnn_block(gf_dim,gf_dim*2,4,2,1,)
   self.e3 = cnn_block(gf_dim*2,gf_dim*4,4,2,1,)
   self.e4 = cnn_block(gf_dim*4,gf_dim*8,4,2,1,)
   self.e5 = cnn_block(gf_dim*8,gf_dim*8,4,2,1,)
   self.e6 = cnn_block(gf_dim*8,gf_dim*8,4,2,1,)
   self.e7 = cnn_block(gf_dim*8,gf_dim*8,4,2,1,)
   self.e8 = cnn_block(gf_dim*8,gf_dim*8,4,2,1, first_layer=True)

   self.d1 = tcnn_block(gf_dim*8,gf_dim*8,4,2,1)
   self.d2 = tcnn_block(gf_dim*8*2,gf_dim*8,4,2,1)
   self.d3 = tcnn_block(gf_dim*8*2,gf_dim*8,4,2,1)
   self.d4 = tcnn_block(gf_dim*8*2,gf_dim*8,4,2,1)
   self.d5 = tcnn_block(gf_dim*8*2,gf_dim*4,4,2,1)
   self.d6 = tcnn_block(gf_dim*4*2,gf_dim*2,4,2,1)
   self.d7 = tcnn_block(gf_dim*2*2,gf_dim*1,4,2,1)
   self.d8 = tcnn_block(gf_dim*1*2,c_dim,4,2,1, first_layer = True)#256x256
   self.tanh = nn.Tanh()

 def forward(self,x):
   e1 = self.e1(x)
   e2 = self.e2(F.leaky_relu(e1,0.2))
   e3 = self.e3(F.leaky_relu(e2,0.2))
   e4 = self.e4(F.leaky_relu(e3,0.2))
   e5 = self.e5(F.leaky_relu(e4,0.2))
   e6 = self.e6(F.leaky_relu(e5,0.2))
   e7 = self.e7(F.leaky_relu(e6,0.2))
   e8 = self.e8(F.leaky_relu(e7,0.2))
   d1 = torch.cat([F.dropout(self.d1(F.relu(e8)),0.5,training=True),e7],1)
   d2 = torch.cat([F.dropout(self.d2(F.relu(d1)),0.5,training=True),e6],1)
   d3 = torch.cat([F.dropout(self.d3(F.relu(d2)),0.5,training=True),e5],1)
   d4 = torch.cat([self.d4(F.relu(d3)),e4],1)
   d5 = torch.cat([self.d5(F.relu(d4)),e3],1)
   d6 = torch.cat([self.d6(F.relu(d5)),e2],1)
   d7 = torch.cat([self.d7(F.relu(d6)),e1],1)
   d8 = self.d8(F.relu(d7))

   return self.tanh(d8)

Discriminator

Let’s define the discriminator using the downsampling function. 

class Discriminator(nn.Module):
 def __init__(self,instance_norm=False):#input : 256x256
   super(Discriminator,self).__init__()
   self.conv1 = cnn_block(c_dim*2,df_dim,4,2,1, first_layer=True) # 128x128
   self.conv2 = cnn_block(df_dim,df_dim*2,4,2,1)# 64x64
   self.conv3 = cnn_block(df_dim*2,df_dim*4,4,2,1)# 32 x 32
   self.conv4 = cnn_block(df_dim*4,df_dim*8,4,1,1)# 31 x 31
   self.conv5 = cnn_block(df_dim*8,1,4,1,1, first_layer=True)# 30 x 30

   self.sigmoid = nn.Sigmoid()
 def forward(self, x, y):
   O = torch.cat([x,y],dim=1)
   O = F.leaky_relu(self.conv1(O),0.2)
   O = F.leaky_relu(self.conv2(O),0.2)
   O = F.leaky_relu(self.conv3(O),0.2)
   O = F.leaky_relu(self.conv4(O),0.2)
   O = self.conv5(O)

   return self.sigmoid(O)

Initializing the models

Let's initialize both models and move them to the GPU (if available) for faster training.

G = Generator().to(device)
D = Discriminator().to(device)

We will also define the optimizers and the loss function. 

G_optimizer = optim.Adam(G.parameters(), lr=2e-4,betas=(0.5,0.999))
D_optimizer = optim.Adam(D.parameters(), lr=2e-4,betas=(0.5,0.999))

bce_criterion = nn.BCELoss()
L1_criterion = nn.L1Loss()

Training and monitoring our model

Training the model is not the last step. You need to monitor the training and track it to analyze the performance and implement changes if necessary. Given how taxing it can get to monitor the performance of a GAN with too many losses, plots, and metrics to deal with, we will use neptune.ai at this step.

Neptune allows the user to:

  1. Monitor the live performance of the model
  2. Monitor the performance of the hardware
  3. Store and compare different metadata for different runs (like metrics, parameters, performance, data, etc.)
  4. Share the work with others

Disclaimer

Please note that this article references a deprecated version of Neptune.

For information on the latest version with improved features and functionality, please visit our website.

To get started, just follow these steps:

1. Install the Python neptune library on your local system:

!pip install neptune

2. Sign up at neptune.ai.

3. Create a project for storing your metadata.

4. Save your credentials as environment variables.

For this project, we will log our parameters into the Neptune dashboard. For logging the parameters or any information into the dashboard, you can use a run object:

import neptune
import os

run = neptune.init_run(
    project=os.getenv("NEPTUNE_PROJECT_NAME"),
    api_token=os.getenv("NEPTUNE_API_TOKEN")
)

A run object establishes a connection between your environment and the project’s dashboard you’ve created for this tutorial. To log metadata, like the dictionary below, you can use the following syntax:

# Logging parameter in Neptune
PARAMS = {'Epoch': epochs,
          'Batch Size': batch_size,
          'Input Channels': c_dim,
          'Workers': workers,
          'Optimizer': 'Adam',
          'Learning Rate': 2e-4,
          'Metrics': ['Binary Cross Entropy', 'L1 Loss'],
          'Activation': ['Leaky Relu', 'Tanh', 'Sigmoid'],
          'Device': device}

run['parameters'] = PARAMS

To log the loss, generated images, and the model’s weights, we will use the run object again but with different methods like append or upload. Here is our training loop putting together everything we have along with Neptune logging:

# Loss history and a fixed validation pair for visualizing progress
G_losses = []
D_losses = []
G_GAN_losses = []
G_L1_losses = []
img_list = []

# The dataset returns (real, condition); keep the same (y, x) convention as the training loop
fixed_y, fixed_x = next(iter(train_loader))
fixed_x = fixed_x.to(device)
fixed_y = fixed_y.to(device)

for ep in range(epochs):
    for i, data in enumerate(train_loader):

        y, x = data
        x = x.to(device)
        y = y.to(device)

        b_size = x.shape[0]

        real_class = torch.ones(b_size, 1, 30, 30).to(device)
        fake_class = torch.zeros(b_size, 1, 30, 30).to(device)

        # Train D
        D.zero_grad()

        real_patch = D(y, x)
        real_gan_loss = bce_criterion(real_patch, real_class)

        fake = G(x)

        fake_patch = D(fake.detach(), x)
        fake_gan_loss = bce_criterion(fake_patch, fake_class)

        D_loss = real_gan_loss + fake_gan_loss
        D_loss.backward()
        D_optimizer.step()

        # Train G
        G.zero_grad()
        fake_patch = D(fake, x)
        fake_gan_loss = bce_criterion(fake_patch, real_class)

        L1_loss = L1_criterion(fake, y)
        G_loss = fake_gan_loss + L1_lambda * L1_loss
        G_loss.backward()

        G_optimizer.step()

        # Neptune logging
        run["Gen Loss"].append(G_loss.item())
        run["Dis Loss"].append(D_loss.item())
        run["L1 Loss"].append(L1_loss.item())
        run["Gen GAN Loss"].append(fake_gan_loss.item())

        if (i + 1) % 5 == 0:
            print(
                "Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f},D(real): {:.2f}, D(fake):{:.2f},g_loss_gan:{:.4f},g_loss_L1:{:.4f}".format(
                    ep,
                    epochs,
                    i + 1,
                    len(train_loader),
                    D_loss.item(),
                    G_loss.item(),
                    real_patch.mean(),
                    fake_patch.mean(),
                    fake_gan_loss.item(),
                    L1_loss.item(),
                )
            )
            G_losses.append(G_loss.item())
            D_losses.append(D_loss.item())
            G_GAN_losses.append(fake_gan_loss.item())
            G_L1_losses.append(L1_loss.item())

            with torch.no_grad():
                G.eval()
                fake = G(fixed_x).detach().cpu()
                G.train()
            figs = plt.figure(figsize=(10, 10))
            plt.subplot(1, 3, 1)
            plt.axis("off")
            plt.title("conditional image (x)")
            plt.imshow(
                np.transpose(
                    torchvision.utils.make_grid(fixed_x.cpu(), nrow=1, padding=5, normalize=True),
                    (1, 2, 0),
                )
            )

            plt.subplot(1, 3, 2)
            plt.axis("off")
            plt.title("fake image")
            plt.imshow(
                np.transpose(
                    torchvision.utils.make_grid(fake, nrow=1, padding=5, normalize=True),
                    (1, 2, 0),
                )
            )

            plt.subplot(1, 3, 3)
            plt.axis("off")
            plt.title("ground truth (y)")
            plt.imshow(
                np.transpose(
                    torchvision.utils.make_grid(fixed_y.cpu(), nrow=1, padding=5, normalize=True),
                    (1, 2, 0),
                )
            )

            run["epoch_results"].upload(figs)
            plt.close()
            img_list.append(figs)
run.stop()

Once training starts, all the logged information is automatically sent to the dashboard. Neptune fetches live information from the training run, which allows live monitoring of the entire process.

Below are the screenshots of the monitoring process. 

Monitoring the performance of the model

You can also access all metadata and see generated samples.

Access to all metadata in the project

Switching to the images panel, it will show you the generated samples:

Access to the generated samples

Key takeaways

  • Pix2Pix is a conditional GAN that uses images and labels to generate images. 
  • It uses two architectures:
    • U-Net for generator
    • PatchGAN for discriminator
  • PatchGAN classifies smaller NxN patches of the generated image as real or fake instead of discriminating the entire image at once. 
  • Pix2Pix has an additional loss specifically for the generator so that it can generate images closer to the ground truth. 
  • Pix2Pix is a paired image-to-image translation algorithm. 

Other GANs that you can explore are:

  1. CycleGAN: It is similar to Pix2Pix, since most of the approach is the same except for the data: instead of paired image translation, it performs unpaired image translation. Learning and exploring CycleGAN will be easier since it was developed by the same authors.
  2. If you are interested in text-to-image translation then you should explore:
    • DeepDaze: Uses a generative model to create images from text prompts. Great for generating abstract or artistic images based on text descriptions.
    • BigSleep: Great if you want to discover unusual visualizations from prompts.
    • DALL·E: Developed by OpenAI, this model generates creative compositions with a high level of detail directly from text descriptions.
  3. Other interesting GAN projects you may want to try out:
    • StyleGAN: Generates realistic faces; ideal for style manipulation and creative blending.
    • AnimeGAN: Converts real photos into anime-style images.
    • BigGAN: Produces images with realistic textures.
    • Age-cGAN: Alters age in facial images.
    • StarGAN: Handles multiple transformations in faces, like hair color and expression changes.
